summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--apioforum/db.py24
-rw-r--r--apioforum/forum.py8
-rw-r--r--apioforum/roles.py2
-rw-r--r--apioforum/templates/common.html4
-rw-r--r--apioforum/templates/view_thread.html2
-rw-r--r--apioforum/thread.py61
6 files changed, 57 insertions, 44 deletions
diff --git a/apioforum/db.py b/apioforum/db.py
index e151cd4..71a52ab 100644
--- a/apioforum/db.py
+++ b/apioforum/db.py
@@ -4,7 +4,7 @@ from flask import current_app, g, abort
from flask.cli import with_appcontext
from werkzeug.routing import BaseConverter
-from db_migrations import migrations
+from .db_migrations import migrations
def get_db():
if 'db' not in g:
@@ -92,7 +92,7 @@ class DbWrapper:
# if this column is a reference, fetch the referenced row as an object
r = self.__class__.references.get(attr, None)
- if r != None:
+ if r != None and self._row[attr] != None:
# do not fetch it more than once
if not attr in self.__dict__:
self.__dict__[attr] = r.fetch(self._row[attr])
@@ -100,18 +100,21 @@ class DbWrapper:
try:
return self._row[attr]
- except KeyError as k:
- raise AttributeError() from k
+ except IndexError as i:
+ try:
+ return self.__dict__[attr]
+ except KeyError:
+ raise AttributeError(attr) from i
def __setattr__(self, attr, value):
- if not self.__class__.primary_key:
- raise(RuntimeError('cannot set attributes on this object'))
-
# special attributes are set on the object itself
if attr[0] == '_':
self.__dict__[attr] = value
return
+ if not self.__class__.primary_key:
+ raise(RuntimeError('cannot set attributes on this object'))
+
cls = self.__class__
if not isinstance(value, DbWrapper):
@@ -119,8 +122,6 @@ class DbWrapper:
else:
v = value._key
- print(f"UPDATE {cls.table} SET {attr} = ? WHERE {cls.primary_key} = ?")
-
get_db().execute(
f"UPDATE {cls.table} SET {attr} = ? WHERE {cls.primary_key} = ?",
(v, self._key))
@@ -148,9 +149,12 @@ class DbWrapper:
# flask path converter
class DbConverter(BaseConverter):
+ # associate names with DbWrapper classes
+ db_classes = {}
+
def __init__(self, m, db_class, abort=True):
super(DbConverter, self).__init__(m)
- self.db_class = db_class
+ self.db_class = self.__class__.db_classes[db_class]
self.abort = abort
def to_python(self, value):
diff --git a/apioforum/forum.py b/apioforum/forum.py
index 6d3f5cf..fde1305 100644
--- a/apioforum/forum.py
+++ b/apioforum/forum.py
@@ -7,7 +7,7 @@ from flask import (
g, redirect, url_for, flash, abort
)
-from .db import get_db
+from .db import get_db, DbWrapper
from .mdrender import render
from .roles import get_forum_roles,has_permission,is_bureaucrat,get_user_role, permissions as role_permissions
from .permissions import is_admin
@@ -17,10 +17,11 @@ import functools
bp = Blueprint("forum", __name__, url_prefix="/")
+forum_references = {}
class Forum(DbWrapper):
table = "forums"
primary_key = "id"
- references = {"parent", Forum}
+ references = forum_references
def has_permission(self, user, permission, login_required=True):
return(has_permission(self._key, str(user), permission, login_required))
@@ -43,6 +44,9 @@ class Forum(DbWrapper):
def get_forum(self):
return self
+# cannot references Forum inside of itself; workaround
+forum_references["parent"] = Forum
+
class Tag(DbWrapper):
table = "tags"
references = {"forum": Forum}
diff --git a/apioforum/roles.py b/apioforum/roles.py
index 0e3c4d3..734c852 100644
--- a/apioforum/roles.py
+++ b/apioforum/roles.py
@@ -1,6 +1,8 @@
from .db import get_db
from .permissions import is_admin
+from flask import g
+import functools
permissions = [
"p_create_threads",
diff --git a/apioforum/templates/common.html b/apioforum/templates/common.html
index f6b6f29..221044f 100644
--- a/apioforum/templates/common.html
+++ b/apioforum/templates/common.html
@@ -3,7 +3,7 @@
{%- endmacro %}
{% macro post_url(post) -%}
- {{url_for('thread.view_thread', thread_id=post.thread)}}#post_{{post.id}}
+ {{url_for('thread.view_thread', obj=post.thread)}}#post_{{post.id}}
{%- endmacro %}
{% macro disp_post(post, buttons=False, forum=None, footer=None) %}
@@ -107,7 +107,7 @@
{% endmacro %}
{% macro vote_meter(poll) %}
- {% set total_votes = poll.total_votes %}
+ {% set total_votes = poll.get_total_votes() %}
{% set n = namespace() %}
{% set n.runningtotal = 0 %}
<svg width="100%" height="15px" xmlns="http://www.w3.org/2000/svg">
diff --git a/apioforum/templates/view_thread.html b/apioforum/templates/view_thread.html
index 06c110b..6859e07 100644
--- a/apioforum/templates/view_thread.html
+++ b/apioforum/templates/view_thread.html
@@ -36,7 +36,7 @@
{% for post in posts %}
{% if post.vote %}
- {% set vote = votes[post.id] %}
+ {% set vote = post.vote %}
{% set option_idx = vote.option_idx %}
{# this is bad but it's going to get refactored anyway #}
diff --git a/apioforum/thread.py b/apioforum/thread.py
index 455530e..bf65a54 100644
--- a/apioforum/thread.py
+++ b/apioforum/thread.py
@@ -6,7 +6,7 @@ from flask import (
Blueprint, render_template, abort, request, g, redirect,
url_for, flash, jsonify
)
-from .db import get_db, DbWrapper
+from .db import get_db, DbWrapper, DbConverter
from .roles import has_permission, requires_permission
from .forum import Forum, Tag
from .user import User
@@ -16,25 +16,32 @@ bp = Blueprint("thread", __name__, url_prefix="/thread")
class Poll(DbWrapper):
table = "polls"
- @classmethod
- def get_row(cls, key):
- db = get_db()
- row = db.execute("""
- SELECT polls.*,total_vote_counts.total_votes FROM polls
- LEFT OUTER JOIN total_vote_counts ON polls.id = total_vote_counts.poll
- WHERE polls.id = ?;
- """,(key,)).fetchone()
- if row == None:
- return None
- options = db.execute("""
- SELECT poll_options.*, vote_counts.num
- FROM poll_options
- LEFT OUTER JOIN vote_counts ON poll_options.poll = vote_counts.poll
- AND poll_options.option_idx = vote_counts.option_idx
- WHERE poll_options.poll = ?
- ORDER BY option_idx asc;
- """,(key,)).fetchall()
- row['options'] = options
+ def __init__(self, row):
+ super(Poll, self).__init__(row)
+ self.__dict__['options'] = list(
+ PollOption.query_some("""
+ SELECT poll_options.*, vote_counts.num
+ FROM poll_options
+ LEFT OUTER JOIN vote_counts ON poll_options.poll = vote_counts.poll
+ AND poll_options.option_idx = vote_counts.option_idx
+ WHERE poll_options.poll = ?
+ ORDER BY option_idx asc;
+ """,(self,)))
+
+ def get_total_votes(self):
+ t = get_db().execute("""
+ SELECT total_votes FROM total_vote_counts
+ WHERE poll = ?""",(self,)).fetchone()
+ return t[0] if t != None else 0
+
+ def recent_vote(self, user):
+ return Vote.query("""
+ SELECT * FROM votes
+ WHERE poll = ?
+ AND user = ?
+ AND current
+ AND NOT is_retraction;
+ """,(self,user))
class PollOption(DbWrapper):
table = "poll_options"
@@ -60,12 +67,15 @@ class Thread(DbWrapper):
return Tag.query_some("""
SELECT tags.* FROM tags
INNER JOIN thread_tags ON thread_tags.tag = tags.id
+ WHERE thread_tags.thread = ?
ORDER BY tags.id
""",(self,))
def get_forum(self):
return self.forum
+DbConverter.db_classes['thread'] = Thread
+
class Post(DbWrapper):
table = "posts"
references = {"thread": Thread, "author": User, "vote": Vote}
@@ -74,7 +84,7 @@ class Post(DbWrapper):
def post_jump(thread_id, post_id):
return url_for("thread.view_thread",thread_id=thread_id)+"#post_"+str(post_id)
-@bp.route("/<db(Thread):thread>")
+@bp.route("/<db(thread):obj>")
@requires_permission("p_view_threads")
def view_thread(thread):
posts = thread.get_posts()
@@ -83,14 +93,7 @@ def view_thread(thread):
if g.user is None or thread.poll is None:
has_voted = None
else:
- v = Vote.query_one("""
- SELECT * FROM votes
- WHERE poll = ?
- AND user = ?
- AND current
- AND NOT is_retraction;
- """,(thread.poll,g.user))
- has_voted = v is not None
+ has_voted = thread.poll.recent_vote(g.user) is not None
return render_template(
"view_thread.html",