From ccfce769ed9be433a8b0d40910a3398b9db1bd26 Mon Sep 17 00:00:00 2001 From: citrons Date: Wed, 18 Aug 2021 18:36:35 +0000 Subject: view_thread works again --- apioforum/db.py | 24 ++++++++------ apioforum/forum.py | 8 +++-- apioforum/roles.py | 2 ++ apioforum/templates/common.html | 4 +-- apioforum/templates/view_thread.html | 2 +- apioforum/thread.py | 61 +++++++++++++++++++----------------- 6 files changed, 57 insertions(+), 44 deletions(-) diff --git a/apioforum/db.py b/apioforum/db.py index e151cd4..71a52ab 100644 --- a/apioforum/db.py +++ b/apioforum/db.py @@ -4,7 +4,7 @@ from flask import current_app, g, abort from flask.cli import with_appcontext from werkzeug.routing import BaseConverter -from db_migrations import migrations +from .db_migrations import migrations def get_db(): if 'db' not in g: @@ -92,7 +92,7 @@ class DbWrapper: # if this column is a reference, fetch the referenced row as an object r = self.__class__.references.get(attr, None) - if r != None: + if r != None and self._row[attr] != None: # do not fetch it more than once if not attr in self.__dict__: self.__dict__[attr] = r.fetch(self._row[attr]) @@ -100,18 +100,21 @@ class DbWrapper: try: return self._row[attr] - except KeyError as k: - raise AttributeError() from k + except IndexError as i: + try: + return self.__dict__[attr] + except KeyError: + raise AttributeError(attr) from i def __setattr__(self, attr, value): - if not self.__class__.primary_key: - raise(RuntimeError('cannot set attributes on this object')) - # special attributes are set on the object itself if attr[0] == '_': self.__dict__[attr] = value return + if not self.__class__.primary_key: + raise(RuntimeError('cannot set attributes on this object')) + cls = self.__class__ if not isinstance(value, DbWrapper): @@ -119,8 +122,6 @@ class DbWrapper: else: v = value._key - print(f"UPDATE {cls.table} SET {attr} = ? WHERE {cls.primary_key} = ?") - get_db().execute( f"UPDATE {cls.table} SET {attr} = ? WHERE {cls.primary_key} = ?", (v, self._key)) @@ -148,9 +149,12 @@ class DbWrapper: # flask path converter class DbConverter(BaseConverter): + # associate names with DbWrapper classes + db_classes = {} + def __init__(self, m, db_class, abort=True): super(DbConverter, self).__init__(m) - self.db_class = db_class + self.db_class = self.__class__.db_classes[db_class] self.abort = abort def to_python(self, value): diff --git a/apioforum/forum.py b/apioforum/forum.py index 6d3f5cf..fde1305 100644 --- a/apioforum/forum.py +++ b/apioforum/forum.py @@ -7,7 +7,7 @@ from flask import ( g, redirect, url_for, flash, abort ) -from .db import get_db +from .db import get_db, DbWrapper from .mdrender import render from .roles import get_forum_roles,has_permission,is_bureaucrat,get_user_role, permissions as role_permissions from .permissions import is_admin @@ -17,10 +17,11 @@ import functools bp = Blueprint("forum", __name__, url_prefix="/") +forum_references = {} class Forum(DbWrapper): table = "forums" primary_key = "id" - references = {"parent", Forum} + references = forum_references def has_permission(self, user, permission, login_required=True): return(has_permission(self._key, str(user), permission, login_required)) @@ -43,6 +44,9 @@ class Forum(DbWrapper): def get_forum(self): return self +# cannot references Forum inside of itself; workaround +forum_references["parent"] = Forum + class Tag(DbWrapper): table = "tags" references = {"forum": Forum} diff --git a/apioforum/roles.py b/apioforum/roles.py index 0e3c4d3..734c852 100644 --- a/apioforum/roles.py +++ b/apioforum/roles.py @@ -1,6 +1,8 @@ from .db import get_db from .permissions import is_admin +from flask import g +import functools permissions = [ "p_create_threads", diff --git a/apioforum/templates/common.html b/apioforum/templates/common.html index f6b6f29..221044f 100644 --- a/apioforum/templates/common.html +++ b/apioforum/templates/common.html @@ -3,7 +3,7 @@ {%- endmacro %} {% macro post_url(post) -%} - {{url_for('thread.view_thread', thread_id=post.thread)}}#post_{{post.id}} + {{url_for('thread.view_thread', obj=post.thread)}}#post_{{post.id}} {%- endmacro %} {% macro disp_post(post, buttons=False, forum=None, footer=None) %} @@ -107,7 +107,7 @@ {% endmacro %} {% macro vote_meter(poll) %} - {% set total_votes = poll.total_votes %} + {% set total_votes = poll.get_total_votes() %} {% set n = namespace() %} {% set n.runningtotal = 0 %} diff --git a/apioforum/templates/view_thread.html b/apioforum/templates/view_thread.html index 06c110b..6859e07 100644 --- a/apioforum/templates/view_thread.html +++ b/apioforum/templates/view_thread.html @@ -36,7 +36,7 @@ {% for post in posts %} {% if post.vote %} - {% set vote = votes[post.id] %} + {% set vote = post.vote %} {% set option_idx = vote.option_idx %} {# this is bad but it's going to get refactored anyway #} diff --git a/apioforum/thread.py b/apioforum/thread.py index 455530e..bf65a54 100644 --- a/apioforum/thread.py +++ b/apioforum/thread.py @@ -6,7 +6,7 @@ from flask import ( Blueprint, render_template, abort, request, g, redirect, url_for, flash, jsonify ) -from .db import get_db, DbWrapper +from .db import get_db, DbWrapper, DbConverter from .roles import has_permission, requires_permission from .forum import Forum, Tag from .user import User @@ -16,25 +16,32 @@ bp = Blueprint("thread", __name__, url_prefix="/thread") class Poll(DbWrapper): table = "polls" - @classmethod - def get_row(cls, key): - db = get_db() - row = db.execute(""" - SELECT polls.*,total_vote_counts.total_votes FROM polls - LEFT OUTER JOIN total_vote_counts ON polls.id = total_vote_counts.poll - WHERE polls.id = ?; - """,(key,)).fetchone() - if row == None: - return None - options = db.execute(""" - SELECT poll_options.*, vote_counts.num - FROM poll_options - LEFT OUTER JOIN vote_counts ON poll_options.poll = vote_counts.poll - AND poll_options.option_idx = vote_counts.option_idx - WHERE poll_options.poll = ? - ORDER BY option_idx asc; - """,(key,)).fetchall() - row['options'] = options + def __init__(self, row): + super(Poll, self).__init__(row) + self.__dict__['options'] = list( + PollOption.query_some(""" + SELECT poll_options.*, vote_counts.num + FROM poll_options + LEFT OUTER JOIN vote_counts ON poll_options.poll = vote_counts.poll + AND poll_options.option_idx = vote_counts.option_idx + WHERE poll_options.poll = ? + ORDER BY option_idx asc; + """,(self,))) + + def get_total_votes(self): + t = get_db().execute(""" + SELECT total_votes FROM total_vote_counts + WHERE poll = ?""",(self,)).fetchone() + return t[0] if t != None else 0 + + def recent_vote(self, user): + return Vote.query(""" + SELECT * FROM votes + WHERE poll = ? + AND user = ? + AND current + AND NOT is_retraction; + """,(self,user)) class PollOption(DbWrapper): table = "poll_options" @@ -60,12 +67,15 @@ class Thread(DbWrapper): return Tag.query_some(""" SELECT tags.* FROM tags INNER JOIN thread_tags ON thread_tags.tag = tags.id + WHERE thread_tags.thread = ? ORDER BY tags.id """,(self,)) def get_forum(self): return self.forum +DbConverter.db_classes['thread'] = Thread + class Post(DbWrapper): table = "posts" references = {"thread": Thread, "author": User, "vote": Vote} @@ -74,7 +84,7 @@ class Post(DbWrapper): def post_jump(thread_id, post_id): return url_for("thread.view_thread",thread_id=thread_id)+"#post_"+str(post_id) -@bp.route("/") +@bp.route("/") @requires_permission("p_view_threads") def view_thread(thread): posts = thread.get_posts() @@ -83,14 +93,7 @@ def view_thread(thread): if g.user is None or thread.poll is None: has_voted = None else: - v = Vote.query_one(""" - SELECT * FROM votes - WHERE poll = ? - AND user = ? - AND current - AND NOT is_retraction; - """,(thread.poll,g.user)) - has_voted = v is not None + has_voted = thread.poll.recent_vote(g.user) is not None return render_template( "view_thread.html", -- cgit v1.2.3