diff options
| -rw-r--r-- | apioforum/db.py | 24 | ||||
| -rw-r--r-- | apioforum/forum.py | 8 | ||||
| -rw-r--r-- | apioforum/roles.py | 2 | ||||
| -rw-r--r-- | apioforum/templates/common.html | 4 | ||||
| -rw-r--r-- | apioforum/templates/view_thread.html | 2 | ||||
| -rw-r--r-- | apioforum/thread.py | 61 | 
6 files changed, 57 insertions, 44 deletions
diff --git a/apioforum/db.py b/apioforum/db.py index e151cd4..71a52ab 100644 --- a/apioforum/db.py +++ b/apioforum/db.py @@ -4,7 +4,7 @@ from flask import current_app, g, abort  from flask.cli import with_appcontext  from werkzeug.routing import BaseConverter -from db_migrations import migrations +from .db_migrations import migrations  def get_db():      if 'db' not in g: @@ -92,7 +92,7 @@ class DbWrapper:          # if this column is a reference, fetch the referenced row as an object          r = self.__class__.references.get(attr, None) -        if r != None: +        if r != None and self._row[attr] != None:              # do not fetch it more than once              if not attr in self.__dict__:                  self.__dict__[attr] = r.fetch(self._row[attr]) @@ -100,18 +100,21 @@ class DbWrapper:          try:              return self._row[attr] -        except KeyError as k: -            raise AttributeError() from k +        except IndexError as i: +            try: +                return self.__dict__[attr] +            except KeyError: +                raise AttributeError(attr) from i      def __setattr__(self, attr, value): -        if not self.__class__.primary_key: -            raise(RuntimeError('cannot set attributes on this object')) -          # special attributes are set on the object itself          if attr[0] == '_':              self.__dict__[attr] = value              return +        if not self.__class__.primary_key: +            raise(RuntimeError('cannot set attributes on this object')) +          cls = self.__class__          if not isinstance(value, DbWrapper): @@ -119,8 +122,6 @@ class DbWrapper:          else:              v = value._key -        print(f"UPDATE {cls.table} SET {attr} = ? WHERE {cls.primary_key} = ?") -          get_db().execute(              f"UPDATE {cls.table} SET {attr} = ? WHERE {cls.primary_key} = ?",                  (v, self._key)) @@ -148,9 +149,12 @@ class DbWrapper:  # flask path converter  class DbConverter(BaseConverter): +    # associate names with DbWrapper classes +    db_classes = {} +      def __init__(self, m, db_class, abort=True):          super(DbConverter, self).__init__(m) -        self.db_class = db_class +        self.db_class = self.__class__.db_classes[db_class]          self.abort = abort      def to_python(self, value): diff --git a/apioforum/forum.py b/apioforum/forum.py index 6d3f5cf..fde1305 100644 --- a/apioforum/forum.py +++ b/apioforum/forum.py @@ -7,7 +7,7 @@ from flask import (      g, redirect, url_for, flash, abort  ) -from .db import get_db +from .db import get_db, DbWrapper  from .mdrender import render  from .roles import get_forum_roles,has_permission,is_bureaucrat,get_user_role, permissions as role_permissions  from .permissions import is_admin @@ -17,10 +17,11 @@ import functools  bp = Blueprint("forum", __name__, url_prefix="/") +forum_references = {}  class Forum(DbWrapper):      table = "forums"      primary_key = "id" -    references = {"parent", Forum} +    references = forum_references      def has_permission(self, user, permission, login_required=True):          return(has_permission(self._key, str(user), permission, login_required)) @@ -43,6 +44,9 @@ class Forum(DbWrapper):      def get_forum(self):          return self +# cannot references Forum inside of itself; workaround +forum_references["parent"] = Forum +  class Tag(DbWrapper):      table = "tags"      references = {"forum": Forum} diff --git a/apioforum/roles.py b/apioforum/roles.py index 0e3c4d3..734c852 100644 --- a/apioforum/roles.py +++ b/apioforum/roles.py @@ -1,6 +1,8 @@  from .db import get_db  from .permissions import is_admin +from flask import g +import functools  permissions = [      "p_create_threads", diff --git a/apioforum/templates/common.html b/apioforum/templates/common.html index f6b6f29..221044f 100644 --- a/apioforum/templates/common.html +++ b/apioforum/templates/common.html @@ -3,7 +3,7 @@  {%- endmacro %}  {% macro post_url(post) -%} -	{{url_for('thread.view_thread', thread_id=post.thread)}}#post_{{post.id}} +	{{url_for('thread.view_thread', obj=post.thread)}}#post_{{post.id}}  {%- endmacro %}  {% macro disp_post(post, buttons=False, forum=None, footer=None) %} @@ -107,7 +107,7 @@  {% endmacro %}  {% macro vote_meter(poll) %} -	{% set total_votes = poll.total_votes %} +	{% set total_votes = poll.get_total_votes() %}  	{% set n = namespace() %}  	{% set n.runningtotal = 0 %}  	<svg width="100%" height="15px" xmlns="http://www.w3.org/2000/svg"> diff --git a/apioforum/templates/view_thread.html b/apioforum/templates/view_thread.html index 06c110b..6859e07 100644 --- a/apioforum/templates/view_thread.html +++ b/apioforum/templates/view_thread.html @@ -36,7 +36,7 @@      {% for post in posts %}          {% if post.vote %} -            {% set vote = votes[post.id] %}         +            {% set vote = post.vote %}                      {% set option_idx = vote.option_idx %}              {# this is bad but it's going to get refactored anyway #} diff --git a/apioforum/thread.py b/apioforum/thread.py index 455530e..bf65a54 100644 --- a/apioforum/thread.py +++ b/apioforum/thread.py @@ -6,7 +6,7 @@ from flask import (      Blueprint, render_template, abort, request, g, redirect,      url_for, flash, jsonify  ) -from .db import get_db, DbWrapper +from .db import get_db, DbWrapper, DbConverter  from .roles import has_permission, requires_permission  from .forum import Forum, Tag  from .user import User @@ -16,25 +16,32 @@ bp = Blueprint("thread", __name__, url_prefix="/thread")  class Poll(DbWrapper):      table = "polls" -    @classmethod -    def get_row(cls, key): -        db = get_db() -        row = db.execute(""" -            SELECT polls.*,total_vote_counts.total_votes FROM polls -            LEFT OUTER JOIN total_vote_counts ON polls.id = total_vote_counts.poll -            WHERE polls.id = ?;                 -            """,(key,)).fetchone() -        if row == None: -            return None -        options = db.execute(""" -            SELECT poll_options.*, vote_counts.num -            FROM poll_options -            LEFT OUTER JOIN vote_counts  ON poll_options.poll = vote_counts.poll -                                        AND poll_options.option_idx = vote_counts.option_idx  -            WHERE poll_options.poll = ? -            ORDER BY option_idx asc; -            """,(key,)).fetchall() -        row['options'] = options +    def __init__(self, row): +        super(Poll, self).__init__(row) +        self.__dict__['options'] = list( +            PollOption.query_some(""" +                SELECT poll_options.*, vote_counts.num +                FROM poll_options +                LEFT OUTER JOIN vote_counts  ON poll_options.poll = vote_counts.poll +                                            AND poll_options.option_idx = vote_counts.option_idx  +                WHERE poll_options.poll = ? +                ORDER BY option_idx asc; +                """,(self,))) + +    def get_total_votes(self): +        t = get_db().execute(""" +            SELECT total_votes FROM total_vote_counts  +            WHERE poll = ?""",(self,)).fetchone() +        return t[0] if t != None else 0 + +    def recent_vote(self, user): +        return Vote.query(""" +            SELECT * FROM votes  +            WHERE poll = ?  +                AND user = ?  +                AND current  +                AND NOT is_retraction; +            """,(self,user))  class PollOption(DbWrapper):      table = "poll_options" @@ -60,12 +67,15 @@ class Thread(DbWrapper):          return Tag.query_some("""              SELECT tags.* FROM tags              INNER JOIN thread_tags ON thread_tags.tag = tags.id +            WHERE thread_tags.thread = ?              ORDER BY tags.id              """,(self,))      def get_forum(self):          return self.forum +DbConverter.db_classes['thread'] = Thread +  class Post(DbWrapper):      table = "posts"      references = {"thread": Thread, "author": User, "vote": Vote} @@ -74,7 +84,7 @@ class Post(DbWrapper):  def post_jump(thread_id, post_id):      return url_for("thread.view_thread",thread_id=thread_id)+"#post_"+str(post_id) -@bp.route("/<db(Thread):thread>") +@bp.route("/<db(thread):obj>")  @requires_permission("p_view_threads")  def view_thread(thread):      posts = thread.get_posts() @@ -83,14 +93,7 @@ def view_thread(thread):      if g.user is None or thread.poll is None:          has_voted = None      else: -        v = Vote.query_one(""" -            SELECT * FROM votes  -            WHERE poll = ?  -                AND user = ?  -                AND current  -                AND NOT is_retraction; -            """,(thread.poll,g.user)) -        has_voted = v is not None +        has_voted = thread.poll.recent_vote(g.user) is not None      return render_template(          "view_thread.html",  | 
