From 61d73d6470096e3afa5dd278f8bcacc37e1bf877 Mon Sep 17 00:00:00 2001
From: ubq323 <ubq323>
Date: Fri, 21 May 2021 22:17:33 +0000
Subject: more auth things

---
 apioforum/auth.py                         | 76 ++++++++++++++++++++++++++++---
 apioforum/db.py                           | 22 +++++++++
 apioforum/templates/auth/login.html.j2    |  2 +-
 apioforum/templates/auth/register.html.j2 | 14 ++++++
 apioforum/templates/base.html.j2          | 13 ++++++
 5 files changed, 119 insertions(+), 8 deletions(-)
 create mode 100644 apioforum/templates/auth/register.html.j2

diff --git a/apioforum/auth.py b/apioforum/auth.py
index d19ad57..f558025 100644
--- a/apioforum/auth.py
+++ b/apioforum/auth.py
@@ -1,9 +1,10 @@
 from flask import (
     Blueprint, session, request, url_for, render_template, redirect,
-    flash, 
+    flash, g
 )
+from werkzeug.security import check_password_hash, generate_password_hash
 from .db import get_db
-    
+import functools
 
 bp = Blueprint("auth", __name__, url_prefix="/auth")
 
@@ -14,22 +15,78 @@ def login():
         password = request.form["password"]
         db = get_db()
         err = None
+        user = db.execute(
+            "SELECT password FROM users WHERE username = ?;",(username,)
+        ).fetchone()
         if not username:
-            err = "Username required"
+            err = "username required"
         elif not password:
-            err = "Password required"
-        elif username != "bee" or password != "form":
-            err = "Invalid login"
+            err = "password required"
+        elif user is None or not check_password_hash(user['password'], password):
+            err = "invalid login"
 
         if err is None:
             session.clear()
-            session['user'] = 'bee'
+            session['user'] = username
             return redirect(url_for('auth.cool'))
 
         flash(err)
         
     return render_template("auth/login.html.j2")
 
+@bp.route("/register", methods=("GET","POST"))
+def register():
+    if request.method == "POST":
+        username = request.form["username"]
+        password = request.form["password"]
+        db = get_db()
+        err = None
+        if not username:
+            err = "Username required"
+        elif not password:
+            err = "Password required"
+        elif db.execute(
+            "SELECT 1 FROM users WHERE username = ?;", (username,)
+        ).fetchone() is not None:
+            err = f"User {username} is already registered."
+
+        if err is None:
+            db.execute(
+                "INSERT INTO users (username, password) VALUES (?,?);",
+                (username,generate_password_hash(password))
+            )
+            db.commit()
+            flash("successfully created account")
+            session['user'] = username
+            return redirect(url_for("auth.cool"))
+
+        flash(err)
+            
+    return render_template("auth/register.html.j2")
+
+@bp.route("/logout")
+def logout():
+    session.clear()
+    return redirect(url_for("auth.cool"))
+
+@bp.before_app_request
+def load_user():
+    username = session.get("user")
+    if username is None:
+        g.user = None
+    else:
+        g.user = get_db().execute(
+            "SELECT * FROM users WHERE username = ?;", (username,)
+        ).fetchone()
+
+def login_required(view):
+    @functools.wraps(view)
+    def wrapped(**kwargs):
+        print(g.user)
+        if g.user is None:
+            return redirect(url_for("auth.login"))
+        return view(**kwargs)
+    return wrapped
 
 @bp.route("/cool")
 def cool():
@@ -38,3 +95,8 @@ def cool():
         return "you are not logged in"
     else:
         return f"you are logged in as {user}"
+
+@bp.route("/cooler")
+@login_required
+def cooler():
+    return "bee"
diff --git a/apioforum/db.py b/apioforum/db.py
index 6a45640..a2830fe 100644
--- a/apioforum/db.py
+++ b/apioforum/db.py
@@ -19,6 +19,28 @@ def close_db(e=None):
         db.close()
 
 migrations = [
+"""
+CREATE TABLE users (
+    username TEXT PRIMARY KEY,
+    password TEXT NOT NULL
+);""",
+"""
+CREATE TABLE threads (
+    id INT PRIMARY KEY,
+    title TEXT NOT NULL,
+    creator TEXT NOT NULL REFERENCES users(username),
+    created INT NOT NULL,
+    updated INT NOT NULL
+);
+CREATE TABLE posts (
+    id INT PRIMARY KEY,
+    content TEXT,
+    thread INT NOT NULL REFERENCES threads(id),
+    author TEXT NOT NULL REFERENCES users(username),
+    idx INT NOT NULL
+);
+CREATE INDEX posts_thread_idx ON posts (thread);
+""",
 ]
 
 def init_db():
diff --git a/apioforum/templates/auth/login.html.j2 b/apioforum/templates/auth/login.html.j2
index 5f311cf..c8c67b4 100644
--- a/apioforum/templates/auth/login.html.j2
+++ b/apioforum/templates/auth/login.html.j2
@@ -8,7 +8,7 @@
     <label for="username">Username</label>
     <input name="username" id="username" required>
     <label for="password">Password</label>
-    <input name="password" id="password" required>
+    <input type="password" name="password" id="password" required>
     <input type="submit" value="yes">
 </form>
 {% endblock %}
diff --git a/apioforum/templates/auth/register.html.j2 b/apioforum/templates/auth/register.html.j2
new file mode 100644
index 0000000..f7eab81
--- /dev/null
+++ b/apioforum/templates/auth/register.html.j2
@@ -0,0 +1,14 @@
+{% extends "base.html.j2" %}
+{% block header %}
+    <h1>{% block title %}register{% endblock %}</h1>
+{% endblock %}
+
+{% block content %}
+<form method="post">
+    <label for="username">Username</label>
+    <input name="username" id="username" required>
+    <label for="password">Password</label>
+    <input type="password" name="password" id="password" required>
+    <input type="submit" value="yes">
+</form>
+{% endblock %}
diff --git a/apioforum/templates/base.html.j2 b/apioforum/templates/base.html.j2
index 01339c1..6660686 100644
--- a/apioforum/templates/base.html.j2
+++ b/apioforum/templates/base.html.j2
@@ -6,6 +6,19 @@
         <meta name="viewport" content="width=device-width, initial-scale=1">
     </head>
     <body>
+        <nav>
+            <h1>apioforum</h1>
+            <ul>
+                {% if g.user %}
+                <li>{{ g.user['username'] }}</li>
+                <li><a href="{{ url_for('auth.logout') }}">logout</a></li>
+                {% else %}
+                <li><a href="{{ url_for('auth.login') }}">login</a></li>
+                <li><a href="{{ url_for('auth.register') }}">register</a></li>
+                {% endif %}
+            </ul>
+        </nav>
+           
         {% block header %}{% endblock %}
         {% for msg in get_flashed_messages() %}
             <div class="flash">{{ msg }}</div>
-- 
cgit v1.2.3