diff options
Diffstat (limited to 'apioforum/thread.py')
-rw-r--r-- | apioforum/thread.py | 61 |
1 files changed, 32 insertions, 29 deletions
diff --git a/apioforum/thread.py b/apioforum/thread.py index 455530e..bf65a54 100644 --- a/apioforum/thread.py +++ b/apioforum/thread.py @@ -6,7 +6,7 @@ from flask import ( Blueprint, render_template, abort, request, g, redirect, url_for, flash, jsonify ) -from .db import get_db, DbWrapper +from .db import get_db, DbWrapper, DbConverter from .roles import has_permission, requires_permission from .forum import Forum, Tag from .user import User @@ -16,25 +16,32 @@ bp = Blueprint("thread", __name__, url_prefix="/thread") class Poll(DbWrapper): table = "polls" - @classmethod - def get_row(cls, key): - db = get_db() - row = db.execute(""" - SELECT polls.*,total_vote_counts.total_votes FROM polls - LEFT OUTER JOIN total_vote_counts ON polls.id = total_vote_counts.poll - WHERE polls.id = ?; - """,(key,)).fetchone() - if row == None: - return None - options = db.execute(""" - SELECT poll_options.*, vote_counts.num - FROM poll_options - LEFT OUTER JOIN vote_counts ON poll_options.poll = vote_counts.poll - AND poll_options.option_idx = vote_counts.option_idx - WHERE poll_options.poll = ? - ORDER BY option_idx asc; - """,(key,)).fetchall() - row['options'] = options + def __init__(self, row): + super(Poll, self).__init__(row) + self.__dict__['options'] = list( + PollOption.query_some(""" + SELECT poll_options.*, vote_counts.num + FROM poll_options + LEFT OUTER JOIN vote_counts ON poll_options.poll = vote_counts.poll + AND poll_options.option_idx = vote_counts.option_idx + WHERE poll_options.poll = ? + ORDER BY option_idx asc; + """,(self,))) + + def get_total_votes(self): + t = get_db().execute(""" + SELECT total_votes FROM total_vote_counts + WHERE poll = ?""",(self,)).fetchone() + return t[0] if t != None else 0 + + def recent_vote(self, user): + return Vote.query(""" + SELECT * FROM votes + WHERE poll = ? + AND user = ? + AND current + AND NOT is_retraction; + """,(self,user)) class PollOption(DbWrapper): table = "poll_options" @@ -60,12 +67,15 @@ class Thread(DbWrapper): return Tag.query_some(""" SELECT tags.* FROM tags INNER JOIN thread_tags ON thread_tags.tag = tags.id + WHERE thread_tags.thread = ? ORDER BY tags.id """,(self,)) def get_forum(self): return self.forum +DbConverter.db_classes['thread'] = Thread + class Post(DbWrapper): table = "posts" references = {"thread": Thread, "author": User, "vote": Vote} @@ -74,7 +84,7 @@ class Post(DbWrapper): def post_jump(thread_id, post_id): return url_for("thread.view_thread",thread_id=thread_id)+"#post_"+str(post_id) -@bp.route("/<db(Thread):thread>") +@bp.route("/<db(thread):obj>") @requires_permission("p_view_threads") def view_thread(thread): posts = thread.get_posts() @@ -83,14 +93,7 @@ def view_thread(thread): if g.user is None or thread.poll is None: has_voted = None else: - v = Vote.query_one(""" - SELECT * FROM votes - WHERE poll = ? - AND user = ? - AND current - AND NOT is_retraction; - """,(thread.poll,g.user)) - has_voted = v is not None + has_voted = thread.poll.recent_vote(g.user) is not None return render_template( "view_thread.html", |