summary refs log tree commit diff homepage
path: root/apioforum/db.py
blob: 71a52abf9303edda6a15f5f1aeddc93310799be4 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
import sqlite3
import click
from flask import current_app, g, abort
from flask.cli import with_appcontext
from werkzeug.routing import BaseConverter

from .db_migrations import migrations

def get_db():
    """Return the SQLite connection for the current Flask app context.

    The connection is opened on the first call within an app context, to the
    path in ``current_app.config['DATABASE']``, and cached on ``flask.g`` so
    later calls reuse it; ``close_db`` closes it again. Rows come back as
    ``sqlite3.Row`` (accessible by column name) and declared column types are
    converted via ``PARSE_DECLTYPES``.
    """
    if 'db' not in g:
        g.db = sqlite3.connect(
            current_app.config['DATABASE'],
            detect_types=sqlite3.PARSE_DECLTYPES
        )
        g.db.row_factory = sqlite3.Row
        # foreign key enforcement is a per-connection setting, so enabling it
        # once when the connection is opened is enough. re-issuing the pragma
        # on every call only costs an extra statement each time, and SQLite
        # ignores it anyway while a transaction is open.
        g.db.execute("PRAGMA foreign_keys = ON;")
    return g.db

def close_db(e=None):
    """Close and forget this app context's database connection, if any.

    ``e`` is the teardown exception argument supplied by Flask; it is unused.
    """
    conn = g.pop('db', None)
    if conn is None:
        return
    conn.close()

def init_db():
    """Bring the database up to date by running pending migrations.

    The number of migrations already applied is read from SQLite's
    ``user_version`` pragma. Every script in ``migrations`` from that index
    onward is executed in order; after each one ``user_version`` is bumped and
    the change committed, so a later run resumes after the last migration that
    completed.
    """
    conn = get_db()
    applied = conn.execute("PRAGMA user_version;").fetchone()[0]
    total = len(migrations)
    step = applied
    while step < total:
        conn.executescript(migrations[step])
        conn.execute(f"PRAGMA user_version = {step+1}")
        conn.commit()
        click.echo(f"migration {step}")
        step += 1

# the docstring of the command function below is shown by click as the
# `migrate` command's --help text
@click.command("migrate")
@with_appcontext
def migrate_command():
    """Apply any pending database schema migrations."""
    init_db()
    click.echo("ok")

def init_app(app):
    """Register this module's hooks on a Flask ``app``.

    Adds the ``migrate`` CLI command and makes ``close_db`` run whenever an
    app context is torn down.
    """
    app.cli.add_command(migrate_command)
    app.teardown_appcontext(close_db)


class DbWrapper:
    """Object wrapper around a single row of ``table``.

    Columns are exposed as attributes. Assigning to a (non-underscore)
    attribute issues an ``UPDATE`` of that column on the row identified by
    ``primary_key``; the cached row is then dropped and re-read lazily on the
    next attribute access. Columns listed in ``references`` are returned as
    the referenced ``DbWrapper`` object instead of the raw foreign key.

    Subclasses set ``table`` (and optionally ``primary_key`` and
    ``references``). With a falsy ``primary_key`` instances cannot be
    modified. This class never commits; callers must commit their changes.
    """
    # table the rows are read from; set by subclasses
    table = None
    # column identifying a row; falsy means rows have no usable key
    primary_key = "id"

    # column name -> DbWrapper child class
    # this allows the DbWrapper to automatically fetch the referenced object
    references = {}

    @classmethod
    def get_row(cls, key):
        """Return the ``sqlite3.Row`` whose primary key is ``key``, or None."""
        return get_db().execute(
            f"SELECT * FROM {cls.table} WHERE {cls.primary_key} = ?", (key,)
        ).fetchone()

    @classmethod
    def fetch(cls, key):
        """Return the row whose primary key is ``key``, wrapped in ``cls``.

        Raises:
            KeyError: no row has that key.
        """
        row = cls.get_row(key)
        if row is None:
            raise KeyError(key)
        return cls(row)

    @classmethod
    def query_some(cls, *args, **kwargs):
        """Run a query (arguments as for ``Connection.execute``) and yield
        every result row wrapped in ``cls``."""
        rows = get_db().execute(*args, **kwargs).fetchall()
        for row in rows:
            yield cls(row)

    @classmethod
    def query(cls, *args, **kwargs):
        """Run a query and return its first result row wrapped in ``cls``.

        Raises:
            StopIteration: the query returned no rows.
        """
        return next(cls.query_some(*args, **kwargs))

    def __init__(self, row):
        self._row = row
        if self.__class__.primary_key:
            self._key = row[self.__class__.primary_key]
        else:
            self._key = None

    def __getattr__(self, attr):
        # only reached when normal lookup fails: column names, and
        # underscore attributes that have not been set.

        # special attributes are retrieved from the object itself. failing
        # here instead of consulting the row avoids recursing on a missing
        # _row. startswith() also copes with an empty attribute name.
        if attr.startswith('_'):
            if attr not in self.__dict__:
                raise AttributeError(attr)
            return self.__dict__[attr]

        # changes have been made to the row. fetch it again
        if self._row is None and self._key is not None:
            self._row = self.__class__.get_row(self._key)

        # if this column is a reference, fetch the referenced row as an object.
        # it is cached in the instance dict, where later lookups find it
        # directly; __setattr__ drops that cache when the column changes.
        ref_cls = self.__class__.references.get(attr)
        if ref_cls is not None and self._row[attr] is not None:
            if attr not in self.__dict__:
                self.__dict__[attr] = ref_cls.fetch(self._row[attr])
            return self.__dict__[attr]

        try:
            return self._row[attr]
        except IndexError as err:
            # sqlite3.Row raises IndexError for an unknown column name
            try:
                return self.__dict__[attr]
            except KeyError:
                raise AttributeError(attr) from err

    def __setattr__(self, attr, value):
        # special attributes are set on the object itself
        if attr.startswith('_'):
            self.__dict__[attr] = value
            return

        cls = self.__class__

        if not cls.primary_key:
            raise RuntimeError('cannot set attributes on this object')

        # wrapped rows are stored by their key
        v = value._key if isinstance(value, DbWrapper) else value

        get_db().execute(
            f"UPDATE {cls.table} SET {attr} = ? WHERE {cls.primary_key} = ?",
            (v, self._key))

        # the fetched row is invalidated.
        # avoid extra queries by querying again only if attributes are accessed
        self._row = None
        # a referenced object cached for this column lives in the instance
        # dict and would otherwise keep shadowing the new value
        self.__dict__.pop(attr, None)

    def __eq__(self, other):
        if self.__class__.primary_key:
            if isinstance(other, self.__class__):
                # rows with keys are equivalent if their keys are
                return self.__class__.table == other.__class__.table \
                    and self._key == other._key
            # a row can be compared with its key
            return self._key == other
        # keyless rows can only be compared with other wrapped rows
        if not isinstance(other, DbWrapper):
            return NotImplemented
        return self._row == other._row

    def __conform__(self, protocol):
        # if used in a database query, convert to database key
        if protocol is sqlite3.PrepareProtocol:
            return self._key

# flask path converter
class DbConverter(BaseConverter):
    """Werkzeug URL converter that maps a path segment to a DbWrapper row.

    ``db_class`` (a converter argument) names an entry of ``db_classes``,
    the DbWrapper subclass whose ``fetch`` is used for lookups. When no row
    has the given key, the request is aborted with 404 if ``abort`` is true;
    otherwise the view receives None.
    """
    # name -> DbWrapper subclass; empty here, presumably filled in elsewhere
    db_classes = {}

    def __init__(self, m, db_class, abort=True):
        super().__init__(m)
        self.db_class = type(self).db_classes[db_class]
        self.abort = abort

    def to_python(self, value):
        """Fetch the row keyed by the URL segment ``value``."""
        try:
            found = self.db_class.fetch(value)
        except KeyError:
            if not self.abort:
                return None
            abort(404)
        else:
            return found

    def to_url(self, value):
        """Render a wrapped row (by its key) or a raw key as a URL segment."""
        key = value._key if isinstance(value, self.db_class) else value
        return str(key)