aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorubq323 <ubq323@ubq323.website>2022-09-26 19:45:06 +0100
committerubq323 <ubq323@ubq323.website>2022-09-26 19:45:06 +0100
commitb734d34e39d18ccb616fa8142a04ec9399b1dcd9 (patch)
treea6a26d88301560b90ec75a98304adc9428b7a630
parent2f26d3748f6e39ba953d15b1ca349da5f7660deb (diff)
tagfiltering refactoringrefactor
-rw-r--r--apioforum/forum.py75
-rw-r--r--apioforum/orm.py4
-rw-r--r--apioforum/thread.py49
3 files changed, 32 insertions, 96 deletions
diff --git a/apioforum/forum.py b/apioforum/forum.py
index 93dbddb..e8abe96 100644
--- a/apioforum/forum.py
+++ b/apioforum/forum.py
@@ -26,6 +26,7 @@ class Forum(DBObj,table="forums"):
return tags
+
THREADS_PER_PAGE = 35
bp = Blueprint("forum", __name__, url_prefix="/")
@@ -74,6 +75,29 @@ def requires_bureaucrat(f):
return wrapper
+def make_tagfilter_clause(request_arg,avail_tags):
+ # returns sql clause that matches tags
+ # that have the id contained in tagfilter_arg, if not empty
+ # also returns the tag in question
+ if request_arg == "":
+ return "", None
+
+ try:
+ tagid = int(request_arg)
+ except ValueError:
+ flash(f'invalid tag id "{request_arg}"')
+ abort(400)
+ else:
+ clause = f"AND thread_tags.tag = {tagid}"
+
+ # now find the tag in question
+ for tag in avail_tags:
+ if tag['id'] == tagid:
+ return clause, tag
+
+ flash("that tag doesn't exist or isn't available here")
+ abort(400)
+
@forum_route("",pagination=True)
@requires_permission("p_view_forum", login_required=False)
def view_forum(forum,page=1):
@@ -89,31 +113,10 @@ def view_forum(forum,page=1):
abort(400)
avail_tags = forum.avail_tags()
+ tagfilter_clause, tagfilter_tag =
+ make_tagfilter_clause(request.args.get("tagfilter",""),avail_tags)
- tagfilter = request.args.get("tagfilter",None)
- if tagfilter == "":
- tagfilter = None
- tagfilter_clause = ""
- tagfilter_tag = None
- if tagfilter is not None:
- try:
- tagfilter = int(tagfilter)
- except ValueError:
- flash(f'invalid tag id "{tagfilter}"')
- abort(400)
- else:
- # there is no risk of sql injection because
- # we just checked it is an int
- tagfilter_clause = f"AND thread_tags.tag = {tagfilter}"
- for the_tag in avail_tags:
- if the_tag['id'] == tagfilter:
- tagfilter_tag = the_tag
- break
- else:
- flash("that tag doesn't exist or isn't available here")
- abort(400)
-
-
+ # all threads on this page
threads = Thread.from_row_list(db.execute(
f"""select * from threads
left outer join thread_tags on threads.id = thread_tags.thread
@@ -122,30 +125,8 @@ def view_forum(forum,page=1):
limit ? offset ?;
""",(forum.id, THREADS_PER_PAGE, (page-1)*THREADS_PER_PAGE)).fetchall())
- # i want to preserve this
- #threads = db.execute(
- # f"""SELECT
- # threads.id, threads.title, threads.creator, threads.created,
- # threads.updated, threads.poll, number_of_posts.num_replies,
- # most_recent_posts.created as mrp_created,
- # most_recent_posts.author as mrp_author,
- # most_recent_posts.id as mrp_id,
- # most_recent_posts.content as mrp_content,
- # most_recent_posts.deleted as mrp_deleted
- # FROM threads
- # INNER JOIN most_recent_posts ON most_recent_posts.thread = threads.id
- # INNER JOIN number_of_posts ON number_of_posts.thread = threads.id
- # LEFT OUTER JOIN thread_tags ON threads.id = thread_tags.thread
- # WHERE threads.forum = ? {tagfilter_clause}
- # GROUP BY threads.id
- # ORDER BY {sortby_by} {sortby_dir}
- # LIMIT ? OFFSET ?;
- # """,(
- # forum.id,
- # THREADS_PER_PAGE,
- # (page-1)*THREADS_PER_PAGE,
- # )).fetchall()
+ # total number of threads in this forum, after tag filtering (for pagination bar)
num_threads = db.execute(f"""
SELECT count(*) AS count FROM threads
LEFT OUTER JOIN thread_tags ON threads.id = thread_tags.thread
diff --git a/apioforum/orm.py b/apioforum/orm.py
index b364be1..14747c6 100644
--- a/apioforum/orm.py
+++ b/apioforum/orm.py
@@ -4,7 +4,7 @@
from .db import get_db
class DBObj:
- def __init_subclass__(cls, /, table, **kwargs):
+ def __init_subclass__(cls, *, table, **kwargs):
# DO NOT pass anything with sql special characters in as the table name
super().__init_subclass__(**kwargs)
cls.table_name = table
@@ -13,7 +13,7 @@ class DBObj:
def fetch(cls, *, id):
"""fetch an object from the database, looked up by id."""
db = get_db()
- # xxx this could be sped up by caching this query maybe instead of
+ # XXX this could be sped up by caching this query maybe instead of
# string formatting every time
row = db.execute(f"select * from {cls.table_name} where id = ?",(id,)).fetchone()
if row is None:
diff --git a/apioforum/thread.py b/apioforum/thread.py
index a5862ba..6d2a6c6 100644
--- a/apioforum/thread.py
+++ b/apioforum/thread.py
@@ -17,23 +17,6 @@ POSTS_PER_PAGE = 28
class Thread(DBObj,table="threads"):
fields = ["id","title","creator","created","updated","forum","poll"]
- # maybe this should be on Post instead?????
- @staticmethod
- def which_page(post):
- """ return what page of a thread the given post is on
-
- assumes post ids within a thread are monotonically increasing, which
- is probably correct
- """
- db = get_db()
- amt_before = db.execute("""
- select count(*) as c from posts
- where thread = ? and id < ?""",
- (post.thread,post.id)).fetchone()['c']
-
- page = 1+math.floor(amt_before/POSTS_PER_PAGE)
- return page
-
def tags(self):
db = get_db()
tags = db.execute("""
@@ -83,34 +66,6 @@ def thread_route(relative_path, pagination=False, **kwargs):
return decorator
-def which_page(post_id,return_thread_id=False):
- # on which page lieth the post in question?
- # forget not that page numbers employeth a system that has a base of 1.
- # the
- # we need impart the knowledgf e into ourselves pertaining to the
- # number of things
- # before the thing
- # yes
-
- db = get_db()
- # ASSUMES THAT post ids are consecutive and things
- # this is probably a reasonable assumption
-
- thread_id = db.execute('select thread from posts where id = ?',(post_id,)).fetchone()['thread']
-
- number_of_things_before_the_thing = db.execute('select count(*) as c, thread as t from posts where thread = ? and id < ?;',(thread_id,post_id)).fetchone()['c']
-
-
- page = 1+math.floor(number_of_things_before_the_thing/POSTS_PER_PAGE)
- if return_thread_id:
- return page, thread_id
- else:
- return page
-
-def post_jump(post_id,*,external=False):
- page,thread_id=which_page(post_id,True)
- return url_for("thread.view_thread",thread_id=thread_id,page=page,_external=external)+"#post_"+str(post_id)
-
@thread_route("",pagination=True)
def view_thread(thread,page=1):
if page < 1:
@@ -134,8 +89,8 @@ def view_thread(thread,page=1):
num_posts = db.execute("SELECT count(*) as count FROM posts WHERE posts.thread = ?",(thread.id,)).fetchone()['count']
max_pageno = math.ceil(num_posts/POSTS_PER_PAGE)
- tags = db.execute(
- """SELECT tags.* FROM tags
+ tags = db.execute("""
+ SELECT tags.* FROM tags
INNER JOIN thread_tags ON thread_tags.tag = tags.id
WHERE thread_tags.thread = ?
ORDER BY tags.id""",(thread.id,)).fetchall()