diff options
| -rw-r--r-- | apioforum/forum.py | 75 | ||||
| -rw-r--r-- | apioforum/orm.py | 4 | ||||
| -rw-r--r-- | apioforum/thread.py | 49 | 
3 files changed, 32 insertions, 96 deletions
| diff --git a/apioforum/forum.py b/apioforum/forum.py index 93dbddb..e8abe96 100644 --- a/apioforum/forum.py +++ b/apioforum/forum.py @@ -26,6 +26,7 @@ class Forum(DBObj,table="forums"):          return tags +  THREADS_PER_PAGE = 35  bp = Blueprint("forum", __name__, url_prefix="/") @@ -74,6 +75,29 @@ def requires_bureaucrat(f):      return wrapper +def make_tagfilter_clause(request_arg,avail_tags): +    # returns sql clause that matches tags +    # that have the id contained in tagfilter_arg, if not empty +    # also returns the tag in question +    if request_arg == "": +        return "", None + +    try: +        tagid = int(request_arg) +    except ValueError: +        flash(f'invalid tag id "{request_arg}"') +        abort(400) +    else: +        clause = f"AND thread_tags.tag = {tagid}" + +        # now find the tag in question +        for tag in avail_tags: +            if tag['id'] == tagid: +                return clause, tag + +        flash("that tag doesn't exist or isn't available here") +        abort(400) +  @forum_route("",pagination=True)  @requires_permission("p_view_forum", login_required=False)  def view_forum(forum,page=1): @@ -89,31 +113,10 @@ def view_forum(forum,page=1):          abort(400)      avail_tags = forum.avail_tags() +    tagfilter_clause, tagfilter_tag = +        make_tagfilter_clause(request.args.get("tagfilter",""),avail_tags) -    tagfilter = request.args.get("tagfilter",None) -    if tagfilter == "": -        tagfilter = None -    tagfilter_clause = "" -    tagfilter_tag = None -    if tagfilter is not None: -        try: -            tagfilter = int(tagfilter) -        except ValueError: -            flash(f'invalid tag id "{tagfilter}"') -            abort(400) -        else: -            # there is no risk of sql injection because -            # we just checked it is an int -            tagfilter_clause = f"AND thread_tags.tag = {tagfilter}" -            for the_tag in avail_tags: -                if the_tag['id'] == tagfilter: -                    tagfilter_tag = the_tag -                    break -            else: -                flash("that tag doesn't exist or isn't available here") -                abort(400) -                - +    # all threads on this page      threads = Thread.from_row_list(db.execute(          f"""select * from threads              left outer join thread_tags on threads.id = thread_tags.thread @@ -122,30 +125,8 @@ def view_forum(forum,page=1):              limit ? offset ?;           """,(forum.id, THREADS_PER_PAGE, (page-1)*THREADS_PER_PAGE)).fetchall()) -    # i want to preserve this -    #threads = db.execute( -    #    f"""SELECT -    #        threads.id, threads.title, threads.creator, threads.created, -    #        threads.updated, threads.poll, number_of_posts.num_replies, -    #        most_recent_posts.created as mrp_created, -    #        most_recent_posts.author as mrp_author, -    #        most_recent_posts.id as mrp_id, -    #        most_recent_posts.content as mrp_content, -    #        most_recent_posts.deleted as mrp_deleted -    #    FROM threads -    #    INNER JOIN most_recent_posts ON most_recent_posts.thread = threads.id -    #    INNER JOIN number_of_posts ON number_of_posts.thread = threads.id -    #    LEFT OUTER JOIN thread_tags ON threads.id = thread_tags.thread -    #    WHERE threads.forum = ? {tagfilter_clause} -    #    GROUP BY threads.id -    #    ORDER BY {sortby_by} {sortby_dir} -    #    LIMIT ? OFFSET ?; -    #    """,( -    #        forum.id, -    #        THREADS_PER_PAGE, -    #        (page-1)*THREADS_PER_PAGE, -    #    )).fetchall() +    # total number of threads in this forum, after tag filtering (for pagination bar)      num_threads = db.execute(f"""          SELECT count(*) AS count FROM threads          LEFT OUTER JOIN thread_tags ON threads.id = thread_tags.thread diff --git a/apioforum/orm.py b/apioforum/orm.py index b364be1..14747c6 100644 --- a/apioforum/orm.py +++ b/apioforum/orm.py @@ -4,7 +4,7 @@  from .db import get_db  class DBObj: -    def __init_subclass__(cls, /, table, **kwargs): +    def __init_subclass__(cls, *, table, **kwargs):          # DO NOT pass anything with sql special characters in as the table name          super().__init_subclass__(**kwargs)          cls.table_name = table @@ -13,7 +13,7 @@ class DBObj:      def fetch(cls, *, id):          """fetch an object from the database, looked up by id."""          db = get_db() -        # xxx this could be sped up by caching this query maybe instead of +        # XXX this could be sped up by caching this query maybe instead of          # string formatting every time          row = db.execute(f"select * from {cls.table_name} where id = ?",(id,)).fetchone()          if row is None: diff --git a/apioforum/thread.py b/apioforum/thread.py index a5862ba..6d2a6c6 100644 --- a/apioforum/thread.py +++ b/apioforum/thread.py @@ -17,23 +17,6 @@ POSTS_PER_PAGE = 28  class Thread(DBObj,table="threads"):      fields = ["id","title","creator","created","updated","forum","poll"] -    # maybe this should be on Post instead????? -    @staticmethod -    def which_page(post): -        """ return what page of a thread the given post is on - -        assumes post ids within a thread are monotonically increasing, which -        is probably correct -        """ -        db = get_db() -        amt_before = db.execute(""" -            select count(*) as c from posts -            where thread = ? and id < ?""", -            (post.thread,post.id)).fetchone()['c'] - -        page = 1+math.floor(amt_before/POSTS_PER_PAGE) -        return page -      def tags(self):          db = get_db()          tags = db.execute(""" @@ -83,34 +66,6 @@ def thread_route(relative_path, pagination=False, **kwargs):      return decorator -def which_page(post_id,return_thread_id=False): -    # on which page lieth the post in question? -    # forget not that page numbers employeth a system that has a base of 1. -    # the -    # we need impart the knowledgf e into ourselves pertaining to the -    # number of things -    # before the thing -    # yes - -    db = get_db() -    # ASSUMES THAT post ids are consecutive and things -    # this is probably a reasonable assumption - -    thread_id = db.execute('select thread from posts where id = ?',(post_id,)).fetchone()['thread'] -     -    number_of_things_before_the_thing = db.execute('select count(*) as c, thread as t from posts where thread = ? and id < ?;',(thread_id,post_id)).fetchone()['c'] - -     -    page =  1+math.floor(number_of_things_before_the_thing/POSTS_PER_PAGE) -    if return_thread_id: -        return page, thread_id -    else: -        return page - -def post_jump(post_id,*,external=False): -    page,thread_id=which_page(post_id,True) -    return url_for("thread.view_thread",thread_id=thread_id,page=page,_external=external)+"#post_"+str(post_id) -  @thread_route("",pagination=True)  def view_thread(thread,page=1):      if page < 1: @@ -134,8 +89,8 @@ def view_thread(thread,page=1):      num_posts = db.execute("SELECT count(*) as count FROM posts WHERE posts.thread = ?",(thread.id,)).fetchone()['count']      max_pageno = math.ceil(num_posts/POSTS_PER_PAGE) -    tags = db.execute( -    """SELECT tags.* FROM tags +    tags = db.execute(""" +        SELECT tags.* FROM tags          INNER JOIN thread_tags ON thread_tags.tag = tags.id          WHERE thread_tags.thread = ?          ORDER BY tags.id""",(thread.id,)).fetchall() | 
