From b734d34e39d18ccb616fa8142a04ec9399b1dcd9 Mon Sep 17 00:00:00 2001 From: ubq323 Date: Mon, 26 Sep 2022 19:45:06 +0100 Subject: tagfiltering refactoring --- apioforum/forum.py | 75 ++++++++++++++++++++--------------------------------- apioforum/orm.py | 4 +-- apioforum/thread.py | 49 ++-------------------------------- 3 files changed, 32 insertions(+), 96 deletions(-) diff --git a/apioforum/forum.py b/apioforum/forum.py index 93dbddb..e8abe96 100644 --- a/apioforum/forum.py +++ b/apioforum/forum.py @@ -26,6 +26,7 @@ class Forum(DBObj,table="forums"): return tags + THREADS_PER_PAGE = 35 bp = Blueprint("forum", __name__, url_prefix="/") @@ -74,6 +75,29 @@ def requires_bureaucrat(f): return wrapper +def make_tagfilter_clause(request_arg,avail_tags): + # returns sql clause that matches tags + # that have the id contained in tagfilter_arg, if not empty + # also returns the tag in question + if request_arg == "": + return "", None + + try: + tagid = int(request_arg) + except ValueError: + flash(f'invalid tag id "{request_arg}"') + abort(400) + else: + clause = f"AND thread_tags.tag = {tagid}" + + # now find the tag in question + for tag in avail_tags: + if tag['id'] == tagid: + return clause, tag + + flash("that tag doesn't exist or isn't available here") + abort(400) + @forum_route("",pagination=True) @requires_permission("p_view_forum", login_required=False) def view_forum(forum,page=1): @@ -89,31 +113,10 @@ def view_forum(forum,page=1): abort(400) avail_tags = forum.avail_tags() + tagfilter_clause, tagfilter_tag = + make_tagfilter_clause(request.args.get("tagfilter",""),avail_tags) - tagfilter = request.args.get("tagfilter",None) - if tagfilter == "": - tagfilter = None - tagfilter_clause = "" - tagfilter_tag = None - if tagfilter is not None: - try: - tagfilter = int(tagfilter) - except ValueError: - flash(f'invalid tag id "{tagfilter}"') - abort(400) - else: - # there is no risk of sql injection because - # we just checked it is an int - tagfilter_clause = f"AND thread_tags.tag = {tagfilter}" - for the_tag in avail_tags: - if the_tag['id'] == tagfilter: - tagfilter_tag = the_tag - break - else: - flash("that tag doesn't exist or isn't available here") - abort(400) - - + # all threads on this page threads = Thread.from_row_list(db.execute( f"""select * from threads left outer join thread_tags on threads.id = thread_tags.thread @@ -122,30 +125,8 @@ def view_forum(forum,page=1): limit ? offset ?; """,(forum.id, THREADS_PER_PAGE, (page-1)*THREADS_PER_PAGE)).fetchall()) - # i want to preserve this - #threads = db.execute( - # f"""SELECT - # threads.id, threads.title, threads.creator, threads.created, - # threads.updated, threads.poll, number_of_posts.num_replies, - # most_recent_posts.created as mrp_created, - # most_recent_posts.author as mrp_author, - # most_recent_posts.id as mrp_id, - # most_recent_posts.content as mrp_content, - # most_recent_posts.deleted as mrp_deleted - # FROM threads - # INNER JOIN most_recent_posts ON most_recent_posts.thread = threads.id - # INNER JOIN number_of_posts ON number_of_posts.thread = threads.id - # LEFT OUTER JOIN thread_tags ON threads.id = thread_tags.thread - # WHERE threads.forum = ? {tagfilter_clause} - # GROUP BY threads.id - # ORDER BY {sortby_by} {sortby_dir} - # LIMIT ? OFFSET ?; - # """,( - # forum.id, - # THREADS_PER_PAGE, - # (page-1)*THREADS_PER_PAGE, - # )).fetchall() + # total number of threads in this forum, after tag filtering (for pagination bar) num_threads = db.execute(f""" SELECT count(*) AS count FROM threads LEFT OUTER JOIN thread_tags ON threads.id = thread_tags.thread diff --git a/apioforum/orm.py b/apioforum/orm.py index b364be1..14747c6 100644 --- a/apioforum/orm.py +++ b/apioforum/orm.py @@ -4,7 +4,7 @@ from .db import get_db class DBObj: - def __init_subclass__(cls, /, table, **kwargs): + def __init_subclass__(cls, *, table, **kwargs): # DO NOT pass anything with sql special characters in as the table name super().__init_subclass__(**kwargs) cls.table_name = table @@ -13,7 +13,7 @@ class DBObj: def fetch(cls, *, id): """fetch an object from the database, looked up by id.""" db = get_db() - # xxx this could be sped up by caching this query maybe instead of + # XXX this could be sped up by caching this query maybe instead of # string formatting every time row = db.execute(f"select * from {cls.table_name} where id = ?",(id,)).fetchone() if row is None: diff --git a/apioforum/thread.py b/apioforum/thread.py index a5862ba..6d2a6c6 100644 --- a/apioforum/thread.py +++ b/apioforum/thread.py @@ -17,23 +17,6 @@ POSTS_PER_PAGE = 28 class Thread(DBObj,table="threads"): fields = ["id","title","creator","created","updated","forum","poll"] - # maybe this should be on Post instead????? - @staticmethod - def which_page(post): - """ return what page of a thread the given post is on - - assumes post ids within a thread are monotonically increasing, which - is probably correct - """ - db = get_db() - amt_before = db.execute(""" - select count(*) as c from posts - where thread = ? and id < ?""", - (post.thread,post.id)).fetchone()['c'] - - page = 1+math.floor(amt_before/POSTS_PER_PAGE) - return page - def tags(self): db = get_db() tags = db.execute(""" @@ -83,34 +66,6 @@ def thread_route(relative_path, pagination=False, **kwargs): return decorator -def which_page(post_id,return_thread_id=False): - # on which page lieth the post in question? - # forget not that page numbers employeth a system that has a base of 1. - # the - # we need impart the knowledgf e into ourselves pertaining to the - # number of things - # before the thing - # yes - - db = get_db() - # ASSUMES THAT post ids are consecutive and things - # this is probably a reasonable assumption - - thread_id = db.execute('select thread from posts where id = ?',(post_id,)).fetchone()['thread'] - - number_of_things_before_the_thing = db.execute('select count(*) as c, thread as t from posts where thread = ? and id < ?;',(thread_id,post_id)).fetchone()['c'] - - - page = 1+math.floor(number_of_things_before_the_thing/POSTS_PER_PAGE) - if return_thread_id: - return page, thread_id - else: - return page - -def post_jump(post_id,*,external=False): - page,thread_id=which_page(post_id,True) - return url_for("thread.view_thread",thread_id=thread_id,page=page,_external=external)+"#post_"+str(post_id) - @thread_route("",pagination=True) def view_thread(thread,page=1): if page < 1: @@ -134,8 +89,8 @@ def view_thread(thread,page=1): num_posts = db.execute("SELECT count(*) as count FROM posts WHERE posts.thread = ?",(thread.id,)).fetchone()['count'] max_pageno = math.ceil(num_posts/POSTS_PER_PAGE) - tags = db.execute( - """SELECT tags.* FROM tags + tags = db.execute(""" + SELECT tags.* FROM tags INNER JOIN thread_tags ON thread_tags.tag = tags.id WHERE thread_tags.thread = ? ORDER BY tags.id""",(thread.id,)).fetchall() -- cgit v1.2.3