Browse Source

SQLAlchemy work in progress... switching to laptop (dirty commit)

Dimitri Korsch 3 years ago
parent
commit
59f83ad835

+ 12 - 27
app.py

@@ -1,35 +1,20 @@
 #!/usr/bin/env python
 
-from json import load
-from os import mkdir, path
+import os
+import json
 
-from pycs.database.Database import Database
 from pycs.frontend.WebServer import WebServer
-from pycs.jobs.JobRunner import JobRunner
-from pycs.util.PipelineCache import PipelineCache
 
-if __name__ == '__main__':
-    # load settings
-    print('- load settings')
-    with open('settings.json', 'r') as file:
-        settings = load(file)
-
-    # create projects folder
-    if not path.exists('projects/'):
-        mkdir('projects/')
+print('- Loading settings')
+with open('settings.json') as file:
+    settings = json.load(file)
 
-    # initialize database
-    print('- load database')
-    database = Database('data.sqlite3')
+# create projects folder
+if not os.path.exists('projects/'):
+    os.mkdir('projects/')
 
-    # start job runner
-    print('- start job runner')
-    jobs = JobRunner()
+# start web server
+server = WebServer(settings)
 
-    # create pipeline cache
-    print('- create pipeline cache')
-    pipelines = PipelineCache(jobs)
-
-    # start web server
-    print('- start web server')
-    web_server = WebServer(settings, database, jobs, pipelines)
+if __name__ == '__main__':
+    server.run()

+ 19 - 0
pycs/__init__.py

@@ -0,0 +1,19 @@
+import json
+
+from pathlib import Path
+
+from flask import Flask
+from flask_migrate import Migrate
+from flask_sqlalchemy import SQLAlchemy
+
+
+print('- Loading settings')
+with open('settings.json') as file:
+    settings = json.load(file)
+
+
+app = Flask(__name__)
+app.config["SQLALCHEMY_DATABASE_URI"] = f"sqlite:///{Path.cwd() / settings['database']}"
+app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False
+db = SQLAlchemy(app)
+migrate = Migrate(app, db)

+ 22 - 56
pycs/database/Collection.py

@@ -1,70 +1,36 @@
 from contextlib import closing
 from typing import Iterator
 
-from pycs.database.File import File
+from pycs import db
+from pycs.database.base import NamedBaseModel
 
+class Collection(NamedBaseModel):
 
-class Collection:
-    """
-    database class for collections
-    """
+    # table columns
+    project_id = db.Column(
+        db.Integer, db.ForeignKey("project.id", ondelete="CASCADE"), nullable=False)
 
-    def __init__(self, database, row):
-        self.database = database
+    reference = db.Column(
+        db.String, nullable=False)
 
-        self.identifier = row[0]
-        self.project_id = row[1]
-        self.reference = row[2]
-        self.name = row[3]
-        self.description = row[4]
-        self.position = row[5]
-        self.autoselect = row[6] > 0
+    description = db.Column(
+        db.String)
 
-    def set_name(self, name: str):
-        """
-        set this collection's name
+    position = db.Column(
+        db.Integer, nullable=False)
 
-        :param name: new name
-        :return:
-        """
-        with closing(self.database.con.cursor()) as cursor:
-            cursor.execute('UPDATE collections SET name = ? WHERE id = ?', (name, self.identifier))
-            self.name = name
+    autoselect = db.Column(
+        db.Boolean, nullable=False)
 
-    def remove(self) -> None:
-        """
-        remove this collection from the database
 
-        :return:
-        """
-        with closing(self.database.con.cursor()) as cursor:
-            cursor.execute('DELETE FROM collections WHERE id = ?', [self.identifier])
+    # contraints
+    __table_args__ = (
+        db.UniqueConstraint('project_id', 'reference'),
+    )
 
-    def count_files(self) -> int:
-        """
-        count files associated with this project
-
-        :return: count
-        """
-        with closing(self.database.con.cursor()) as cursor:
-            cursor.execute('SELECT COUNT(*) FROM files WHERE project = ? AND collection = ?',
-                           (self.project_id, self.identifier))
-            return cursor.fetchone()[0]
+    # relationships to other models
+    files = db.relationship("File", backref="collection", lazy=True)
 
-    def files(self, offset: int = 0, limit: int = -1) -> Iterator[File]:
-        """
-        get an iterator of files associated with this collection
 
-        :param offset: file offset
-        :param limit: file limit
-        :return: iterator of files
-        """
-        with closing(self.database.con.cursor()) as cursor:
-            cursor.execute('''
-                SELECT * FROM files
-                WHERE project = ? AND collection = ?
-                ORDER BY id ASC LIMIT ? OFFSET ?
-                ''', (self.project_id, self.identifier, limit, offset))
-
-            for row in cursor:
-                yield File(self.database, row)
+    def count_files(self) -> int:
+        return self.files.count()

+ 11 - 130
pycs/database/Database.py

@@ -3,6 +3,7 @@ from contextlib import closing
 from time import time
 from typing import Optional, Iterator
 
+from pycs import db
 from pycs.database.Collection import Collection
 from pycs.database.File import File
 from pycs.database.LabelProvider import LabelProvider
@@ -18,134 +19,21 @@ class Database:
     opens an sqlite database and allows to access several objects
     """
 
-    def __init__(self, path: str = ':memory:', initialization=True, discovery=True):
-        """
-        opens or creates a given sqlite database and creates all required tables
-
-        :param path: path to sqlite database
-        """
-        # save properties
-        self.path = path
-
-        # initialize database connection
-        self.con = sqlite3.connect(path)
-        self.con.execute("PRAGMA foreign_keys = ON")
-
-        if initialization:
-            # create tables
-            with self, closing(self.con.cursor()) as cursor:
-                cursor.execute('''
-                    CREATE TABLE IF NOT EXISTS models (
-                        id          INTEGER PRIMARY KEY,
-                        name        TEXT                NOT NULL,
-                        description TEXT,
-                        root_folder TEXT                NOT NULL UNIQUE,
-                        supports    TEXT                NOT NULL
-                    )
-                ''')
-                cursor.execute('''
-                    CREATE TABLE IF NOT EXISTS label_providers (
-                        id                 INTEGER PRIMARY KEY,
-                        name               TEXT                NOT NULL,
-                        description        TEXT,
-                        root_folder        TEXT                NOT NULL,
-                        configuration_file TEXT                NOT NULL,
-                        UNIQUE(root_folder, configuration_file)
-                    )
-                ''')
-
-                cursor.execute('''
-                    CREATE TABLE IF NOT EXISTS projects (
-                        id             INTEGER PRIMARY KEY,
-                        name           TEXT                NOT NULL,
-                        description    TEXT,
-                        created        INTEGER             NOT NULL,
-                        model          INTEGER,
-                        label_provider INTEGER,
-                        root_folder    TEXT                NOT NULL UNIQUE,
-                        external_data  BOOL                NOT NULL,
-                        data_folder    TEXT                NOT NULL,
-                        FOREIGN KEY (model) REFERENCES models(id)
-                            ON UPDATE CASCADE ON DELETE SET NULL,
-                        FOREIGN KEY (label_provider) REFERENCES label_providers(id)
-                            ON UPDATE CASCADE ON DELETE SET NULL
-                    )
-                ''')
-                cursor.execute('''
-                    CREATE TABLE IF NOT EXISTS labels (
-                        id              INTEGER PRIMARY KEY,
-                        project         INTEGER             NOT NULL,
-                        parent          INTEGER,
-                        created         INTEGER             NOT NULL,
-                        reference       TEXT,
-                        name            TEXT                NOT NULL,
-                        hierarchy_level TEXT,
-                        FOREIGN KEY (project) REFERENCES projects(id)
-                            ON UPDATE CASCADE ON DELETE CASCADE,
-                        FOREIGN KEY (parent) REFERENCES labels(id)
-                            ON UPDATE CASCADE ON DELETE SET NULL,
-                        UNIQUE(project, reference)
-                    )
-                ''')
-                cursor.execute('''
-                    CREATE TABLE IF NOT EXISTS collections (
-                        id          INTEGER          PRIMARY KEY,
-                        project     INTEGER NOT NULL,
-                        reference   TEXT    NOT NULL,
-                        name        TEXT    NOT NULL,
-                        description TEXT,
-                        position    INTEGER NOT NULL,
-                        autoselect  INTEGER NOT NULL,
-                        FOREIGN KEY (project) REFERENCES projects(id)
-                            ON UPDATE CASCADE ON DELETE CASCADE,
-                        UNIQUE(project, reference)
-                    )
-                ''')
-                cursor.execute('''
-                    CREATE TABLE IF NOT EXISTS files (
-                        id         INTEGER PRIMARY KEY,
-                        uuid       TEXT                NOT NULL,
-                        project    INTEGER             NOT NULL,
-                        collection INTEGER,
-                        type       TEXT                NOT NULL,
-                        name       TEXT                NOT NULL,
-                        extension  TEXT                NOT NULL,
-                        size       INTEGER             NOT NULL,
-                        created    INTEGER             NOT NULL,
-                        path       TEXT                NOT NULL,
-                        frames     INTEGER,
-                        fps        FLOAT,
-                        FOREIGN KEY (project) REFERENCES projects(id)
-                            ON UPDATE CASCADE ON DELETE CASCADE,
-                        FOREIGN KEY (collection) REFERENCES collections(id)
-                            ON UPDATE CASCADE ON DELETE SET NULL,
-                        UNIQUE(project, path)
-                    )
-                ''')
-                cursor.execute('''
-                    CREATE TABLE IF NOT EXISTS results (
-                        id     INTEGER PRIMARY KEY,
-                        file   INTEGER             NOT NULL,
-                        origin TEXT                NOT NULL,
-                        type   TEXT                NOT NULL,
-                        label  INTEGER,
-                        data   TEXT,
-                        FOREIGN KEY (file) REFERENCES files(id)
-                            ON UPDATE CASCADE ON DELETE CASCADE
-                    )
-                ''')
+    def __init__(self, discovery: bool = True):
+        """
+        wrapper for some DB-related runctions. TODO: remove it!
+
+        """
 
         if discovery:
-            # run discovery modules
-            with self:
-                discover_models(self.con)
-                discover_label_providers(self.con)
+            discover_models()
+            discover_label_providers()
 
     def close(self):
         """
         close database file
         """
-        self.con.close()
+        return
 
     def copy(self):
         """
@@ -155,20 +43,13 @@ class Database:
 
         :return: Database
         """
-        return Database(self.path, initialization=False, discovery=False)
+        return self
 
     def commit(self):
         """
         commit changes
         """
-        self.con.commit()
-
-    def __enter__(self):
-        self.con.__enter__()
-        return self
-
-    def __exit__(self, exc_type, exc_val, exc_tb):
-        self.con.__exit__(exc_type, exc_val, exc_tb)
+        db.session.commit()
 
     def get_object_by_id(self, table_name: str, identifier: int, cls):
         """

+ 95 - 189
pycs/database/File.py

@@ -1,256 +1,162 @@
-from contextlib import closing
+from __future__ import annotations
+
+import json
 import os
-from json import dumps
-from typing import List
-from typing import Optional
+import typing as T
+
+from datetime import datetime
 
+from pycs import db
+from pycs.database.Collection import Collection
 from pycs.database.Result import Result
+from pycs.database.base import NamedBaseModel
+from pycs.database.util import commit_on_return
 
 
-class File:
+class File(NamedBaseModel):
     """
     database class for files
     """
 
-    def __init__(self, database, row):
-        self.database = database
+    # table columns
+    uuid = db.Column(db.String, nullable=False)
 
-        self.identifier = row[0]
-        self.uuid = row[1]
-        self.project_id = row[2]
-        self.collection_id = row[3]
-        self.type = row[4]
-        self.name = row[5]
-        self.extension = row[6]
-        self.size = row[7]
-        self.created = row[8]
-        self.path = row[9]
-        self.frames = row[10]
-        self.fps = row[11]
+    extension = db.Column(db.String, nullable=False)
 
+    type = db.Column(db.String, nullable=False)
 
-    @property
-    def absolute_path(self):
-        if os.path.isabs(self.path):
-            return self.path
+    size = db.Column(db.String, nullable=False)
 
-        return os.path.join(os.getcwd(), self.path)
+    created = db.Column(db.DateTime, default=datetime.utcnow,
+        index=True, nullable=False)
 
-    def project(self):
-        """
-        get the project associated with this file
+    path = db.Column(db.String, nullable=False)
 
-        :return: project
-        """
-        return self.database.project(self.project_id)
+    frames = db.Column(db.Integer)
 
-    def collection(self):
-        """
-        get the collection associated with this file
+    fps = db.Column(db.Float)
 
-        :return: collection
-        """
-        if self.collection_id is None:
-            return None
+    project_id = db.Column(
+        db.Integer,
+        db.ForeignKey("project.id", ondelete="CASCADE"),
+        nullable=False)
 
-        return self.database.collection(self.collection_id)
+    collection_id = db.Column(
+        db.Integer,
+        db.ForeignKey("collection.id", ondelete="SET NULL"))
+
+    # contraints
+    __table_args__ = (
+        db.UniqueConstraint('project_id', 'path'),
+    )
+
+
+    # relationships to other models
+    results = db.relationship("Result", backref="file", lazy=True)
+
+
+    @property
+    def absolute_path(self):
+        if os.path.isabs(self.path):
+            return self.path
+
+        return os.path.join(os.getcwd(), self.path)
 
-    def set_collection(self, collection_id: Optional[int]):
+    @commit_on_return
+    def set_collection(self, id: T.Optional[int]):
         """
         set this file's collection
 
-        :param collection_id: new collection
+        :param id: new collection id
         :return:
         """
-        with closing(self.database.con.cursor()) as cursor:
-            cursor.execute('UPDATE files SET collection = ? WHERE id = ?',
-                           (collection_id, self.identifier))
-            self.collection_id = collection_id
 
-    def set_collection_by_reference(self, collection_reference: Optional[str]):
+        self.collection_id = id
+
+    @commit_on_return
+    def set_collection_by_reference(self, collection_reference: T.Optional[str]):
         """
         set this file's collection
 
         :param collection_reference: collection reference
         :return:
         """
-        if collection_reference is None:
+        if self.collection_reference is None:
             self.set_collection(None)
-            return
-
-        with closing(self.database.con.cursor()) as cursor:
-            cursor.execute('SELECT id FROM collections WHERE reference = ?', [collection_reference])
-            row = cursor.fetchone()
 
-        self.set_collection(row[0] if row is not None else None)
+        collection = Collection.query.filter_by(reference=collection_reference).one()
+        self.collection = collection
 
-    def remove(self) -> None:
+    def _get_another_file(self, *query) -> T.Optional[File]:
         """
-        remove this file from the database
+        get the first file matching the query ordered by descending id
 
-        :return:
+        :return: another file or None
         """
-        with closing(self.database.con.cursor()) as cursor:
-            cursor.execute('DELETE FROM files WHERE id = ?', [self.identifier])
+        return File.query.filter(File.project_id == self.project_id, *query)\
+            .order_by(File.id.desc())\
+            .first()
 
-    def previous(self):
+    def next(self) -> T.Optional[File]:
         """
-        get the predecessor of this file
+        get the successor of this file
 
-        :return: another file
+        :return: another file or None
         """
-        with closing(self.database.con.cursor()) as cursor:
-            cursor.execute('''
-                SELECT * FROM files WHERE id < ? AND project = ? ORDER BY id DESC LIMIT 1
-            ''', (self.identifier, self.project_id))
-            row = cursor.fetchone()
-
-            if row is not None:
-                return File(self.database, row)
+        query = File.id > self.id,
+        return self._get_another_file(*query)
 
-            return None
 
-    def next(self):
+    def previous(self) -> T.Optional[File]:
         """
-        get the successor of this file
+        get the predecessor of this file
 
-        :return: another file
+        :return: another file or None
         """
-        with closing(self.database.con.cursor()) as cursor:
-            cursor.execute('''
-                    SELECT * FROM files WHERE id > ? AND project = ? ORDER BY id ASC LIMIT 1
-                ''', (self.identifier, self.project_id))
-            row = cursor.fetchone()
-
-            if row is not None:
-                return File(self.database, row)
+        query = File.id < self.id,
+        return self._get_another_file(*query)
 
-            return None
 
-    def previous_in_collection(self):
+    def next_in_collection(self) -> T.Optional[File]:
         """
         get the predecessor of this file
 
-        :return: another file
+        :return: another file or None
         """
-        with closing(self.database.con.cursor()) as cursor:
-            if self.collection_id is None:
-                cursor.execute('''
-                    SELECT * FROM files
-                    WHERE id < ? AND project = ? AND collection IS NULL
-                    ORDER BY id DESC
-                    LIMIT 1
-                ''', (self.identifier, self.project_id))
-            else:
-                cursor.execute('''
-                    SELECT * FROM files
-                    WHERE id < ? AND project = ? AND collection = ?
-                    ORDER BY id DESC
-                    LIMIT 1
-                ''', (self.identifier, self.project_id, self.collection_id))
-
-            row = cursor.fetchone()
-            if row is not None:
-                return File(self.database, row)
-
-            return None
-
-    def next_in_collection(self):
-        """
-        get the successor of this file
+        query = File.id > self.id, File.collection_id == self.collection_id
+        return self._get_another_file(*query)
 
-        :return: another file
-        """
-        with closing(self.database.con.cursor()) as cursor:
-            if self.collection_id is None:
-                cursor.execute('''
-                    SELECT * FROM files
-                    WHERE id > ? AND project = ? AND collection IS NULL
-                    ORDER BY id ASC
-                    LIMIT 1
-                ''', (self.identifier, self.project_id))
-            else:
-                cursor.execute('''
-                    SELECT * FROM files
-                    WHERE id > ? AND project = ? AND collection = ?
-                    ORDER BY id ASC
-                    LIMIT 1
-                ''', (self.identifier, self.project_id, self.collection_id))
-
-            row = cursor.fetchone()
-            if row is not None:
-                return File(self.database, row)
-
-            return None
-
-    def results(self) -> List[Result]:
-        """
-        get a list of all results associated with this file
 
-        :return: list of results
-        """
-        with closing(self.database.con.cursor()) as cursor:
-            cursor.execute('SELECT * FROM results WHERE file = ?', [self.identifier])
-            return list(map(
-                lambda row: Result(self.database, row),
-                cursor.fetchall()
-            ))
-
-    def result(self, identifier) -> Optional[Result]:
+    def previous_in_collection(self) -> T.Optional[File]:
         """
-        get a specific result using its unique identifier
+        get the predecessor of this file
 
-        :param identifier: unique identifier
-        :return: result
+        :return: another file or None
         """
-        with closing(self.database.con.cursor()) as cursor:
-            cursor.execute('''
-                SELECT * FROM results WHERE id = ? AND file = ?
-            ''', (identifier, self.identifier))
-            row = cursor.fetchone()
-
-            if row is not None:
-                return Result(self.database, row)
+        query = File.id < self.id, File.collection_id == self.collection_id
+        return self._get_another_file(*query)
 
-            return None
 
-    def create_result(self, origin, result_type, label, data=None):
-        """
-        create a result
+    def result(self, id: int) -> T.Optional[Result]:
+        return self.results.get(id)
 
-        :param origin:
-        :param result_type:
-        :param label:
-        :param data:
-        :return:
-        """
-        if data is not None:
-            data = dumps(data)
 
-        with closing(self.database.con.cursor()) as cursor:
-            cursor.execute('''
-                INSERT INTO results (file, origin, type, label, data)
-                VALUES              (   ?,      ?,    ?,     ?,    ?)
-            ''', (self.identifier, origin, result_type, label, data))
+    def create_result(self, origin, result_type, label, data: T.Optional[dict] = None):
+        data = data if data is None else json.dumps(data)
 
-            return self.result(cursor.lastrowid)
+        result = Result.new(commit=True,
+                            file=self,
+                            origin=origin,
+                            type=result_type,
+                            label=label,
+                            data=data)
+        return result
 
-    def remove_results(self, origin='pipeline') -> List[Result]:
-        """
-        remove all results with the specified origin
 
-        :param origin: either 'pipeline' or 'user'
-        :return: list of removed results
-        """
-        with closing(self.database.con.cursor()) as cursor:
-            cursor.execute('''
-                SELECT * FROM results WHERE file = ? AND origin = ?
-            ''', (self.identifier, origin))
+    def remove_results(self, origin='pipeline'):
 
-            results = list(map(lambda row: Result(self.database, row), cursor.fetchall()))
+        results = Result.query.filter(Result.file == self, Result.origin == origin)
 
-            cursor.execute('''
-                DELETE FROM results WHERE file = ? AND origin = ?
-            ''', (self.identifier, origin))
+        results.remove()
 
-            return results
+        return results

+ 39 - 86
pycs/database/Label.py

@@ -1,107 +1,60 @@
+from __future__ import annotations
 from contextlib import closing
+from datetime import datetime
 
+from pycs import db
+from pycs.database.base import NamedBaseModel
+from pycs.database.util import commit_on_return
 
-class Label:
-    """
-    database class for labels
-    """
+def compare_children(start_label: Label, id: int):
+    """ check for cyclic relationships """
 
-    def __init__(self, database, row):
-        self.database = database
+    labels_to_check = [start_label]
 
-        self.identifier = row[0]
-        self.project_id = row[1]
-        self.parent_id = row[2]
-        self.created = row[3]
-        self.reference = row[4]
-        self.name = row[5]
-        self.hierarchy_level = row[6]
+    while labels_to_check:
+        label = labels_to_check.pop(0)
 
-    def project(self):
-        """
-        get the project this label is associated with
+        if label.id == id:
+            return False
 
-        :return: project
-        """
-        return self.database.project(self.project_id)
+        labels_to_check.extend(label.children)
 
-    def set_name(self, name: str):
-        """
-        set this labels name
+    return True
 
-        :param name: new name
-        :return:
-        """
-        with closing(self.database.con.cursor()) as cursor:
-            cursor.execute('UPDATE labels SET name = ? WHERE id = ?', (name, self.identifier))
-            self.name = name
+class Label(NamedBaseModel):
 
-    def set_parent(self, parent_id: int):
-        """
-        set this labels parent
+    id = db.Column(db.Integer, primary_key=True)
+    project_id = db.Column(
+        db.Integer,
+        db.ForeignKey("project.id", ondelete="CASCADE"),
+        nullable=False)
 
-        :param parent_id: parent's id
-        :return:
-        """
+    parent_id = db.Column(
+        db.Integer,
+        db.ForeignKey("label.id", ondelete="SET NULL"))
 
-        # check for cyclic relationships
-        def compare_children(label, identifier):
-            if label.identifier == identifier:
-                return False
+    created = db.Column(db.DateTime, default=datetime.utcnow,
+        index=True, nullable=False)
 
-            for child in label.children():
-                if not compare_children(child, identifier):
-                    return False
+    reference = db.Column(db.String)
 
-            return True
+    # contraints
+    __table_args__ = (
+        db.UniqueConstraint('project_id', 'reference'),
+    )
 
-        if not compare_children(self, parent_id):
-            raise ValueError('parent_id')
-
-        # insert parent id
-        with closing(self.database.con.cursor()) as cursor:
-            cursor.execute('UPDATE labels SET parent = ? WHERE id = ?',
-                           (parent_id, self.identifier))
-            self.parent_id = parent_id
+    # relationships to other models
+    parent = db.relationship("Label", backref="children", remote_side=[id])
 
-    def remove(self):
+    @commit_on_return
+    def set_parent(self, parent_id: int, commit: bool = True):
         """
-        remove this label from the database
+        set this labels parent
 
+        :param parent_id: parent's id
         :return:
         """
-        with closing(self.database.con.cursor()) as cursor:
-            cursor.execute('DELETE FROM labels WHERE id = ?', [self.identifier])
-
-    def parent(self):
-        """
-        get this labels parent from the database
-
-        :return: parent or None
-        """
-        if self.parent_id is None:
-            return None
-
-        with closing(self.database.con.cursor()) as cursor:
-            cursor.execute('SELECT * FROM labels WHERE id = ? AND project = ?',
-                           (self.parent_id, self.project_id))
-            row = cursor.fetchone()
-
-            if row is not None:
-                return Label(self.database, row)
-
-        return None
-
-    def children(self):
-        """
-        get this labels children from the database
+        if not compare_children(self, parent_id):
+            raise ValueError('Cyclic relationship detected!')
 
-        :return: list of children
-        """
-        with closing(self.database.con.cursor()) as cursor:
-            cursor.execute('SELECT * FROM labels WHERE parent = ? AND project = ?',
-                           (self.identifier, self.project_id))
-            return list(map(
-                lambda row: Label(self.database, row),
-                cursor.fetchall()
-            ))
+        self.parent_id = parent_id

+ 7 - 13
pycs/database/LabelProvider.py

@@ -1,22 +1,16 @@
-import json
-from os import path
+from pycs import db
+from pycs.database.base import NamedBaseModel
 
-from pycs.interfaces.LabelProvider import LabelProvider as LabelProviderInterface
-
-
-class LabelProvider:
+class LabelProvider(NamedBaseModel):
     """
     database class for label providers
     """
 
-    def __init__(self, database, row):
-        self.database = database
+    description = db.Column(db.String)
+    root_folder = db.Column(db.String, nullable=False, unique=True)
 
-        self.identifier = row[0]
-        self.name = row[1]
-        self.description = row[2]
-        self.root_folder = row[3]
-        self.configuration_file = row[4]
+    # relationships to other models
+    projects = db.relationship("Project", backref="label_provider", lazy=True)
 
     @property
     def configuration_path(self):

+ 33 - 51
pycs/database/Model.py

@@ -1,56 +1,38 @@
-from contextlib import closing
-from json import loads, dumps
+import json
 
+from pycs import db
+from pycs.database.base import NamedBaseModel
+from pycs.database.util import commit_on_return
 
-class Model:
+class Model(NamedBaseModel):
     """
-    database class for label providers
+    database class for ML Models
     """
 
-    def __init__(self, database, row):
-        self.database = database
-
-        self.identifier = row[0]
-        self.name = row[1]
-        self.description = row[2]
-        self.root_folder = row[3]
-        self.supports = loads(row[4])
-
-    def copy_to(self, name: str, root_folder: str):
-        """
-        copies the models database entry while changing name and root_folder
-
-        :param name: copy name
-        :param root_folder: copy root folder
-        :return: copy
-        """
-        supports = dumps(self.supports)
-
-        with closing(self.database.con.cursor()) as cursor:
-            cursor.execute('''
-                INSERT INTO models (name, description, root_folder, supports)
-                VALUES (?, ?, ?, ?)
-                ON CONFLICT (root_folder)
-                DO UPDATE SET name = ?, description = ?, supports = ?
-            ''', (name, self.description, root_folder, supports, name, self.description, supports))
-
-            # lastrowid is 0 if on conflict clause applies.
-            # If this is the case we do an extra query to receive the row id.
-            if cursor.lastrowid > 0:
-                row_id = cursor.lastrowid
-                insert = True
-            else:
-                cursor.execute('SELECT id FROM models WHERE root_folder = ?', [root_folder])
-                row_id = cursor.fetchone()[0]
-                insert = False
-
-        return self.database.model(row_id), insert
-
-    def remove(self):
-        """
-        remove this model from the database
-
-        :return:
-        """
-        with closing(self.database.con.cursor()) as cursor:
-            cursor.execute('DELETE FROM models WHERE id = ?', [self.identifier])
+    description = db.Column(db.String)
+    root_folder = db.Column(db.String, nullable=False, unique=True)
+    supports_encoded = db.Column(db.String, nullable=False)
+
+    # relationships to other models
+    projects = db.relationship("Project", backref="model", lazy=True)
+
+    @property
+    def supports(self):
+        return json.loads(self.supports_encoded)
+
+
+    @commit_on_return
+    def copy_to(self, new_name: str, new_root_folder: str):
+
+        model = Model.query.get(root_folder=new_root_folder)
+        is_new = False
+
+        if model is None:
+            model = Model.new(root_folder=new_root_folder)
+            is_new = True
+
+        model.name = name
+        model.description = self.description
+        model.supports_encoded = self.supports_encoded
+
+        return model, is_new

+ 148 - 296
pycs/database/Project.py

@@ -1,135 +1,90 @@
+import typing as T
+
 from contextlib import closing
+from datetime import datetime
 from os.path import join
-from time import time
-from typing import List, Optional, Tuple, Iterator, Union
+from typing import Iterator
+from typing import List
+from typing import Optional
+from typing import Tuple
+
+from pycs import db
+from pycs.database.base import NamedBaseModel
 
 from pycs.database.Collection import Collection
 from pycs.database.File import File
 from pycs.database.Label import Label
-from pycs.database.LabelProvider import LabelProvider
-from pycs.database.Model import Model
+from pycs.database.util import commit_on_return
 from pycs.database.util.TreeNodeLabel import TreeNodeLabel
 
+class Project(NamedBaseModel):
+    description = db.Column(db.String)
 
-class Project:
-    """
-    database class for projects
-    """
+    created = db.Column(db.DateTime, default=datetime.utcnow,
+        index=True, nullable=False)
 
-    def __init__(self, database, row):
-        self.database = database
+    model_id = db.Column(
+        db.Integer,
+        db.ForeignKey("model.id", ondelete="SET NULL"))
 
-        self.identifier = row[0]
-        self.name = row[1]
-        self.description = row[2]
-        self.created = row[3]
-        self.model_id = row[4]
-        self.label_provider_id = row[5]
-        self.root_folder = row[6]
-        self.external_data = bool(row[7])
-        self.data_folder = row[8]
+    label_provider_id = db.Column(
+        db.Integer,
+        db.ForeignKey("label_provider.id", ondelete="SET NULL"))
 
-    def model(self) -> Model:
-        """
-        get the model this project is associated with
+    root_folder = db.Column(db.String, nullable=False, unique=True)
 
-        :return: model
-        """
-        return self.database.model(self.model_id)
+    external_data = db.Column(db.Boolean, nullable=False)
 
-    def label_provider(self) -> Optional[LabelProvider]:
-        """
-        get the label provider this project is associated with
+    data_folder = db.Column(db.String, nullable=False)
 
-        :return: label provider
-        """
-        if self.label_provider_id is not None:
-            return self.database.label_provider(self.label_provider_id)
+    # contraints
+    __table_args__ = ()
 
-        return None
+    # relationships to other models
+    files = db.relationship("File", backref="project", lazy=True)
+    labels = db.relationship("Label", backref="project", lazy=True)
+    collections = db.relationship("Collection", backref="project", lazy=True)
 
-    def labels(self) -> List[Label]:
-        """
-        get a list of labels associated with this project
 
-        :return: list of labels
+    def label(self, id: int) -> T.Optional[Label]:
         """
-        with closing(self.database.con.cursor()) as cursor:
-            cursor.execute('SELECT * FROM labels WHERE project = ?', [self.identifier])
-            return list(map(
-                lambda row: Label(self.database, row),
-                cursor.fetchall()
-            ))
+        get a label using its unique identifier
 
-    def label_tree(self) -> List[TreeNodeLabel]:
+        :param identifier: unique identifier
+        :return: label
         """
-        get a list of root labels associated with this project
+        return self.labels.get(id)
 
-        :return: list of labels
+    def file(self, id: int) -> T.Optional[Label]:
         """
-        with closing(self.database.con.cursor()) as cursor:
-            cursor.execute('''
-                WITH RECURSIVE
-                    tree AS (
-                        SELECT labels.* FROM labels
-                            WHERE project = ? AND parent IS NULL
-                        UNION ALL
-                        SELECT labels.* FROM labels
-                            JOIN tree ON labels.parent = tree.id
-                    )
-                SELECT * FROM tree
-            ''', [self.identifier])
-
-            result = []
-            lookup = {}
-
-            for row in cursor.fetchall():
-                label = TreeNodeLabel(self.database, row)
-                lookup[label.identifier] = label
-
-                if label.parent_id is None:
-                    result.append(label)
-                else:
-                    lookup[label.parent_id].children.append(label)
+        get a file using its unique identifier
 
-            return result
+        :param identifier: unique identifier
+        :return: file
+        """
+        return self.files.get(id)
 
-    def label(self, identifier: int) -> Optional[Label]:
+    def collection(self, id: int) -> T.Optional[Collection]:
         """
-        get a label using its unique identifier
+        get a collection using its unique identifier
 
         :param identifier: unique identifier
-        :return: label
+        :return: collection
         """
-        with closing(self.database.con.cursor()) as cursor:
-            cursor.execute('SELECT * FROM labels WHERE id = ? AND project = ?',
-                           (identifier, self.identifier))
-            row = cursor.fetchone()
-
-            if row is not None:
-                return Label(self.database, row)
+        return self.collections.get(id)
 
-            return None
-
-    def label_by_reference(self, reference: str) -> Optional[Label]:
+    def collection_by_reference(self, reference: str) -> T.Optional[Collection]:
         """
-        get a label using its reference string
+        get a collection using its unique identifier
 
-        :param reference: reference string
-        :return: label
+        :param identifier: unique identifier
+        :return: collection
         """
-        with closing(self.database.con.cursor()) as cursor:
-            cursor.execute('SELECT * FROM labels WHERE reference = ? AND project = ?',
-                           (reference, self.identifier))
-            row = cursor.fetchone()
-
-            if row is not None:
-                return Label(self.database, row)
-
-            return None
+        return self.collections.filter_by(reference=reference).one()
 
+    @commit_on_return
     def create_label(self, name: str, reference: str = None,
-                     parent: Union[Label, int, str] = None,
+                     parent_id: int = None,
                      hierarchy_level: str = None) -> Tuple[Optional[Label], bool]:
         """
         create a label for this project. If there is already a label with the same reference
@@ -137,88 +92,59 @@ class Project:
 
         :param name: label name
         :param reference: label reference
-        :param parent: either parent identifier, parent reference string or `Label` object
+        :param parent_id: parent's identifier
         :param hierarchy_level: hierarchy level name
         :return: created or edited label, insert
         """
-        created = int(time())
 
-        if isinstance(parent, str):
-            parent = self.label_by_reference(parent)
-        if isinstance(parent, Label):
-            parent = parent.identifier
+        label = Label.query.get(project=self, reference=reference)
+        is_new = False
 
-        with closing(self.database.con.cursor()) as cursor:
-            cursor.execute('''
-                INSERT INTO labels (project, parent, created, reference, name, hierarchy_level)
-                VALUES (?, ?, ?, ?, ?, ?)
-                ON CONFLICT (project, reference) DO
-                UPDATE SET parent = ?, name = ?, hierarchy_level = ?
-            ''', (self.identifier, parent, created, reference, name, hierarchy_level,
-                  parent, name, hierarchy_level))
-
-            # lastrowid is 0 if on conflict clause applies.
-            # If this is the case we do an extra query to receive the row id.
-            if cursor.lastrowid > 0:
-                row_id = cursor.lastrowid
-                insert = True
-            else:
-                cursor.execute('SELECT id FROM labels WHERE project = ? AND reference = ?',
-                               (self.identifier, reference))
-                row_id = cursor.fetchone()[0]
-                insert = False
-
-        return self.label(row_id), insert
-
-    def collections(self) -> List[Collection]:
-        """
-        get a list of collections associated with this project
+        if label is None:
+            label = Label.new(project=self, reference=reference)
+            is_new = True
 
-        :return: list of collections
-        """
-        with closing(self.database.con.cursor()) as cursor:
-            cursor.execute('SELECT * FROM collections WHERE project = ? ORDER BY position ASC',
-                           [self.identifier])
+        label.set_name(name, commit=False)
+        label.set_parent(parent_id, commit=False)
+        label.hierarchy_level = hierarchy_level
 
-            return list(map(
-                lambda row: Collection(self.database, row),
-                cursor.fetchall()
-            ))
+        return label, is_new
 
-    def collection(self, identifier: int) -> Optional[Collection]:
+    def label_tree(self) -> List[TreeNodeLabel]:
         """
-        get a collection using its unique identifier
+        get a list of root labels associated with this project
 
-        :param identifier: unique identifier
-        :return: collection
+        :return: list of labels
         """
+        raise NotImplementedError
         with closing(self.database.con.cursor()) as cursor:
-            cursor.execute('SELECT * FROM collections WHERE id = ? AND project = ?',
-                           (identifier, self.identifier))
-            row = cursor.fetchone()
-
-            if row is not None:
-                return Collection(self.database, row)
-
-            return None
+            cursor.execute('''
+                WITH RECURSIVE
+                    tree AS (
+                        SELECT labels.* FROM labels
+                            WHERE project = ? AND parent IS NULL
+                        UNION ALL
+                        SELECT labels.* FROM labels
+                            JOIN tree ON labels.parent = tree.id
+                    )
+                SELECT * FROM tree
+            ''', [self.identifier])
 
-    def collection_by_reference(self, reference: str):
-        """
-        get a collection using its reference string
+            result = []
+            lookup = {}
 
-        :param reference: reference string
-        :return: collection
-        """
-        with closing(self.database.con.cursor()) as cursor:
-            cursor.execute('SELECT * FROM collections WHERE reference = ? AND project = ?',
-                           (reference, self.identifier))
-            row = cursor.fetchone()
+            for row in cursor.fetchall():
+                label = TreeNodeLabel(self.database, row)
+                lookup[label.identifier] = label
 
-            if row is not None:
-                return Collection(self.database, row)
+                if label.parent_id is None:
+                    result.append(label)
+                else:
+                    lookup[label.parent_id].children.append(label)
 
-            return None
+            return result
 
+    @commit_on_return
     def create_collection(self,
                           reference: str,
                           name: str,
@@ -236,74 +162,75 @@ class Project:
 
         :return: collection object, insert
         """
-        autoselect = 1 if autoselect else 0
 
-        with closing(self.database.con.cursor()) as cursor:
-            cursor.execute('''
-                INSERT INTO collections
-                    (project, reference, name, description, position, autoselect)
-                VALUES (?, ?, ?, ?, ?, ?)
-                ON CONFLICT (project, reference) DO
-                UPDATE SET name = ?, description = ?, position = ?, autoselect = ?
-            ''', (self.identifier, reference, name, description, position, autoselect,
-                  name, description, position, autoselect))
-
-            # lastrowid is 0 if on conflict clause applies.
-            # If this is the case we do an extra query to receive the row id.
-            if cursor.lastrowid > 0:
-                row_id = cursor.lastrowid
-                insert = True
-            else:
-                cursor.execute('SELECT id FROM collections WHERE project = ? AND reference = ?',
-                               (self.identifier, reference))
-                row_id = cursor.fetchone()[0]
-                insert = False
-
-        return self.collection(row_id), insert
-
-    def remove(self) -> None:
-        """
-        remove this project from the database
+        collection = Collection.query.get(project=self, reference=reference)
+        is_new = False
 
-        :return:
-        """
-        with closing(self.database.con.cursor()) as cursor:
-            cursor.execute('DELETE FROM projects WHERE id = ?', [self.identifier])
+        if collection is None:
+            collection = Collection.new(project=self,
+                                        reference=reference)
+            is_new = True
+
+        collection.name = name
+        collection.description = description
+        collection.position = position
+        collection.autoselect = autoselect
+
+        return collection, is_new
 
-    def set_name(self, name: str) -> None:
+    @commit_on_return
+    def add_file(self, uuid: str, file_type: str, name: str, extension: str, size: int,
+                 filename: str, frames: int = None, fps: float = None) -> T.Tuple[File, bool]:
         """
-        set this projects name
+        add a file to this project
 
-        :param name: new name
-        :return:
+        :param uuid: unique identifier which is used for temporary files
+        :param file_type: file type (either image or video)
+        :param name: file name
+        :param extension: file extension
+        :param size: file size
+        :param filename: actual name in filesystem
+        :param frames: frame count
+        :param fps: frames per second
+        :return: file
         """
-        with closing(self.database.con.cursor()) as cursor:
-            cursor.execute('UPDATE projects SET name = ? WHERE id = ?', (name, self.identifier))
-            self.name = name
+        path = join(self.data_folder, filename + extension)
+
+        file = File.objects.get(project=self, path=path)
+        is_new = False
+
+        if file is None:
+            file = File.new(uuid=uuid, project=self, path=path)
+            is_new = True
 
-    def set_description(self, description: str) -> None:
+        file.type = file_type
+        file.name = name
+        file.extension = extension
+        file.size = size
+        file.frames = frames
+        file.fps = fps
+
+        return file, is_new
+
+
+    def set_description(self, description: str):
         """
         set this projects description
 
         :param description: new description
         :return:
         """
-        with closing(self.database.con.cursor()) as cursor:
-            cursor.execute('UPDATE projects SET description = ? WHERE id = ?',
-                           (description, self.identifier))
-            self.description = description
-
+        self.description = description
+        self
     def count_files(self) -> int:
         """
         count files associated with this project
 
         :return: count
         """
-        with closing(self.database.con.cursor()) as cursor:
-            cursor.execute('SELECT COUNT(*) FROM files WHERE project = ?', [self.identifier])
-            return cursor.fetchone()[0]
+        return self.files.count()
 
-    def files(self, offset: int = 0, limit: int = -1) -> Iterator[File]:
+    def get_files(self, offset: int = 0, limit: int = -1) -> T.Iterator[File]:
         """
         get an iterator of files associated with this project
 
@@ -311,14 +238,7 @@ class Project:
         :param limit: file limit
         :return: iterator of files
         """
-        with closing(self.database.con.cursor()) as cursor:
-            cursor.execute('SELECT * FROM files WHERE project = ? ORDER BY id ASC LIMIT ? OFFSET ?',
-                           (self.identifier, limit, offset))
-
-            return map(
-                lambda row: File(self.database, row),
-                cursor.fetchall()
-            )
+        return self.files.order_by(File.id.acs()).offset(offset).limit(limit)
 
     def count_files_without_results(self) -> int:
         """
@@ -326,6 +246,8 @@ class Project:
 
         :return: count
         """
+        raise NotImplementedError
+
         with closing(self.database.con.cursor()) as cursor:
             cursor.execute('''
                 SELECT COUNT(*)
@@ -341,6 +263,8 @@ class Project:
 
         :return: list of files
         """
+        raise NotImplementedError
+
         with closing(self.database.con.cursor()) as cursor:
             cursor.execute('''
                 SELECT files.*
@@ -353,91 +277,19 @@ class Project:
             for row in cursor:
                 yield File(self.database, row)
 
-    def count_files_without_collection(self) -> int:
-        """
-        count files associated with this project but with no collection
-
-        :return: count
-        """
-        with closing(self.database.con.cursor()) as cursor:
-            cursor.execute('SELECT COUNT(*) FROM files WHERE project = ? AND collection IS NULL',
-                           [self.identifier])
-            return cursor.fetchone()[0]
-
-    def files_without_collection(self, offset=0, limit=-1) -> Iterator[File]:
+    def files_without_collection(self, offset: int = 0, limit: int = -1) -> Iterator[File]:
         """
         get an iterator of files without not associated with any collection
 
         :return: list of files
         """
-        with closing(self.database.con.cursor()) as cursor:
-            cursor.execute('''
-                SELECT * FROM files
-                WHERE files.project = ? AND files.collection IS NULL
-                ORDER BY id ASC
-                LIMIT ? OFFSET ?
-            ''', (self.identifier, limit, offset))
+        return self.get_files(offset, limit).filter(File.collection_id == None)
 
-            for row in cursor:
-                yield File(self.database, row)
 
-    def file(self, identifier) -> Optional[File]:
-        """
-        get a file using its unique identifier
-
-        :param identifier: unique identifier
-        :return: file
-        """
-        with closing(self.database.con.cursor()) as cursor:
-            cursor.execute('SELECT * FROM files WHERE id = ? AND project = ?',
-                           (identifier, self.identifier))
-            row = cursor.fetchone()
-
-            if row is not None:
-                return File(self.database, row)
-
-            return None
-
-    def add_file(self, uuid: str, file_type: str, name: str, extension: str, size: int,
-                 filename: str, frames: int = None, fps: float = None) -> Tuple[File, bool]:
+    def count_files_without_collection(self) -> int:
         """
-        add a file to this project
+        count files associated with this project but with no collection
 
-        :param uuid: unique identifier which is used for temporary files
-        :param file_type: file type (either image or video)
-        :param name: file name
-        :param extension: file extension
-        :param size: file size
-        :param filename: actual name in filesystem
-        :param frames: frame count
-        :param fps: frames per second
-        :return: file
+        :return: count
         """
-        created = int(time())
-        path = join(self.data_folder, filename + extension)
-
-        with closing(self.database.con.cursor()) as cursor:
-            cursor.execute('''
-                INSERT INTO files (
-                    uuid, project, type, name, extension, size, created, path, frames, fps
-                )
-                VALUES (  
-                       ?,       ?,    ?,    ?,         ?,    ?,       ?,    ?,      ?,   ?
-                )
-                ON CONFLICT (project, path) DO
-                UPDATE SET type = ?, name = ?, extension = ?, size = ?, frames = ?, fps = ?
-            ''', (uuid, self.identifier, file_type, name, extension, size, created, path, frames,
-                  fps, file_type, name, extension, size, frames, fps))
-
-            # lastrowid is 0 if on conflict clause applies.
-            # If this is the case we do an extra query to receive the row id.
-            if cursor.lastrowid > 0:
-                row_id = cursor.lastrowid
-                insert = True
-            else:
-                cursor.execute('SELECT id FROM files WHERE project = ? AND path = ?',
-                               (self.identifier, path))
-                row_id = cursor.fetchone()[0]
-                insert = False
-
-        return self.file(row_id), insert
+        return self.files_without_collection().count()

+ 29 - 45
pycs/database/Result.py

@@ -1,41 +1,31 @@
-from contextlib import closing
+import typing as T
 
+from contextlib import closing
 from json import dumps, loads
 
+from pycs import db
+from pycs.database.base import BaseModel
+from pycs.database.util import commit_on_return
 
-class Result:
-    """
-    database class for results
-    """
-
-    def __init__(self, database, row):
-        self.database = database
+class Result(BaseModel):
 
-        self.identifier = row[0]
-        self.file_id = row[1]
-        self.origin = row[2]
-        self.type = row[3]
-        self.label = row[4]
-        self.data = loads(row[5]) if row[5] is not None else None
+    file_id = db.Column(
+        db.Integer,
+        db.ForeignKey("file.id", ondelete="CASCADE"),
+        nullable=False)
 
-    def file(self):
-        """
-        getter for the according file
-
-        :return: file object
-        """
+    origin = db.Column(db.String, nullable=False)
+    type = db.Column(db.String, nullable=False)
 
-        return self.database.file(self.file_id)
+    label_id = db.Column(
+        db.Integer,
+        db.ForeignKey("label.id", ondelete="SET NULL"),
+        nullable=True)
 
-    def remove(self):
-        """
-        remove this result from the database
+    data = db.Column(db.String)
 
-        :return:
-        """
-        with closing(self.database.con.cursor()) as cursor:
-            cursor.execute('DELETE FROM results WHERE id = ?', [self.identifier])
 
+    @commit_on_return
     def set_origin(self, origin: str):
         """
         set this results origin
@@ -43,33 +33,27 @@ class Result:
         :param origin: either 'user' or 'pipeline'
         :return:
         """
-        with closing(self.database.con.cursor()) as cursor:
-            cursor.execute('UPDATE results SET origin = ? WHERE id = ?', (origin, self.identifier))
-            self.origin = origin
+        self.origin = origin
 
+
+    @commit_on_return
     def set_label(self, label: int):
         """
-        set this results label
+        set this results origin
 
-        :param label: label id
+        :param label: label ID
         :return:
         """
-        with closing(self.database.con.cursor()) as cursor:
-            cursor.execute('UPDATE results SET label = ? WHERE id = ?', (label, self.identifier))
-            self.label = label
+        self.label_id = label
 
-    def set_data(self, data: dict):
+    @commit_on_return
+    def set_data(self, data: T.Optional[dict]):
         """
         set this results data object
 
         :param data: data object
         :return:
         """
-        if data is None:
-            data_txt = None
-        else:
-            data_txt = dumps(data)
-
-        with closing(self.database.con.cursor()) as cursor:
-            cursor.execute('UPDATE results SET data = ? WHERE id = ?', (data_txt, self.identifier))
-            self.data = data
+        data = data if data is None else json.dumps(data)
+
+        self.data = data

+ 40 - 0
pycs/database/base.py

@@ -0,0 +1,40 @@
+
+from pycs import db
+from pycs.database.util import commit_on_return
+
+class BaseModel(db.Model):
+    __abstract__ = True
+
+    id = db.Column(db.Integer, primary_key=True)
+
+
+    def remove(self, commit: bool = True) -> None:
+        """
+        remove this instance from the database
+
+        :return:
+        """
+        db.session.delete(self)
+
+        if commit:
+            self.commit()
+
+    @classmethod
+    def new(cls, commit=False, **kwargs):
+        obj = cls(**kwargs)
+        db.session.add(obj)
+
+        if commit:
+            self.commit()
+
+    def commit(self):
+        db.session.commit()
+
+class NamedBaseModel(BaseModel):
+    __abstract__ = True
+
+    name = db.Column(db.String, nullable=False)
+
+    @commit_on_return
+    def set_name(self, name: str):
+        self.name = name

+ 18 - 0
pycs/database/util/__init__.py

@@ -0,0 +1,18 @@
+from functools import wraps
+
+from pycs import db
+
+def commit_on_return(method):
+
+	@warps(method)
+	def inner(self, *args, commit: bool = True, **kwargs):
+
+		res = method(self, *args, **kwargs)
+
+		if commit:
+			db.session.commit()
+
+		return res
+
+	return inner
+

+ 101 - 88
pycs/frontend/WebServer.py

@@ -1,11 +1,16 @@
+import eventlet
+import os
+import socketio
+
 from glob import glob
-from os import path, getcwd
+from os import getcwd
+from os import path
 from os.path import exists
 
-import eventlet
-import socketio
-from flask import Flask, send_from_directory
+from flask import Flask
+from flask import send_from_directory
 
+from pycs import app
 from pycs.database.Database import Database
 from pycs.frontend.endpoints.ListJobs import ListJobs
 from pycs.frontend.endpoints.ListLabelProviders import ListLabelProviders
@@ -58,9 +63,13 @@ class WebServer:
 
     # pylint: disable=line-too-long
     # pylint: disable=too-many-statements
-    def __init__(self, settings: dict, database: Database, jobs: JobRunner, pipelines: PipelineCache):
+    def __init__(self, app, settings: dict):
+        self.database = Database()
+
+        PRODUCTION = os.path.exists('webui/index.html')
+
         # initialize web server
-        if exists('webui/index.html'):
+        if PRODUCTION:
             print('production build')
 
             # find static files and folders
@@ -82,11 +91,10 @@ class WebServer:
             else:
                 self.__sio = socketio.Server(async_mode='eventlet')
 
-            self.__flask = Flask(__name__)
-            self.__app = socketio.WSGIApp(self.__sio, self.__flask, static_files=static_files)
+            self.__app = socketio.WSGIApp(self.__sio, app, static_files=static_files)
 
             # overwrite root path to serve index.html
-            @self.__flask.route('/', methods=['GET'])
+            @app.route('/', methods=['GET'])
             def index():
                 # pylint: disable=unused-variable
                 return send_from_directory(path.join(getcwd(), 'webui'), 'index.html')
@@ -96,18 +104,17 @@ class WebServer:
 
             # create service objects
             self.__sio = socketio.Server(cors_allowed_origins='*', async_mode='eventlet')
-            self.__flask = Flask(__name__)
-            self.__app = socketio.WSGIApp(self.__sio, self.__flask)
+            self.__app = socketio.WSGIApp(self.__sio, app)
 
             # set access control header to allow requests from Vue.js development server
-            @self.__flask.after_request
+            @app.after_request
             def after_request(response):
                 # pylint: disable=unused-variable
                 response.headers['Access-Control-Allow-Origin'] = '*'
                 return response
 
         # set json encoder so database objects are serialized correctly
-        self.__flask.json_encoder = JSONEncoder
+        app.json_encoder = JSONEncoder
 
         # create notification manager
         notifications = NotificationManager(self.__sio)
@@ -118,183 +125,189 @@ class WebServer:
         jobs.on_finish(notifications.edit_job)
         jobs.on_remove(notifications.remove_job)
 
+        self.define_routes(jobs, notifications, pipelines)
+
+
+    def define_routes(self, jobs, notifications, pipelines):
+
         # additional
-        self.__flask.add_url_rule(
+        app.add_url_rule(
             '/folder',
             view_func=FolderInformation.as_view('folder_information')
         )
 
         # jobs
-        self.__flask.add_url_rule(
+        app.add_url_rule(
             '/jobs',
             view_func=ListJobs.as_view('list_jobs', jobs)
         )
-        self.__flask.add_url_rule(
+        app.add_url_rule(
             '/jobs/<identifier>/remove',
             view_func=RemoveJob.as_view('remove_job', jobs)
         )
 
         # models
-        self.__flask.add_url_rule(
+        app.add_url_rule(
             '/models',
-            view_func=ListModels.as_view('list_models', database)
+            view_func=ListModels.as_view('list_models', self.database)
         )
-        self.__flask.add_url_rule(
+        app.add_url_rule(
             '/projects/<int:identifier>/model',
-            view_func=GetProjectModel.as_view('get_project_model', database)
+            view_func=GetProjectModel.as_view('get_project_model', self.database)
         )
 
         # labels
-        self.__flask.add_url_rule(
+        app.add_url_rule(
             '/label_providers',
-            view_func=ListLabelProviders.as_view('label_providers', database)
+            view_func=ListLabelProviders.as_view('label_providers', self.database)
         )
-        self.__flask.add_url_rule(
+        app.add_url_rule(
             '/projects/<int:identifier>/labels',
-            view_func=ListLabels.as_view('list_labels', database)
+            view_func=ListLabels.as_view('list_labels', self.database)
         )
-        self.__flask.add_url_rule(
+        app.add_url_rule(
             '/projects/<int:identifier>/labels/tree',
-            view_func=ListLabelTree.as_view('list_label_tree', database)
+            view_func=ListLabelTree.as_view('list_label_tree', self.database)
         )
-        self.__flask.add_url_rule(
+        app.add_url_rule(
             '/projects/<int:identifier>/labels',
-            view_func=CreateLabel.as_view('create_label', database, notifications)
+            view_func=CreateLabel.as_view('create_label', self.database, notifications)
         )
-        self.__flask.add_url_rule(
+        app.add_url_rule(
             '/projects/<int:project_id>/labels/<int:label_id>/remove',
-            view_func=RemoveLabel.as_view('remove_label', database, notifications)
+            view_func=RemoveLabel.as_view('remove_label', self.database, notifications)
         )
-        self.__flask.add_url_rule(
+        app.add_url_rule(
             '/projects/<int:project_id>/labels/<int:label_id>/name',
-            view_func=EditLabelName.as_view('edit_label_name', database, notifications)
+            view_func=EditLabelName.as_view('edit_label_name', self.database, notifications)
         )
-        self.__flask.add_url_rule(
+        app.add_url_rule(
             '/projects/<int:project_id>/labels/<int:label_id>/parent',
-            view_func=EditLabelParent.as_view('edit_label_parent', database, notifications)
+            view_func=EditLabelParent.as_view('edit_label_parent', self.database, notifications)
         )
 
         # collections
-        self.__flask.add_url_rule(
+        app.add_url_rule(
             '/projects/<int:project_id>/collections',
-            view_func=ListCollections.as_view('list_collections', database)
+            view_func=ListCollections.as_view('list_collections', self.database)
         )
-        self.__flask.add_url_rule(
+        app.add_url_rule(
             '/projects/<int:project_id>/data/<int:collection_id>/<int:start>/<int:length>',
-            view_func=ListFiles.as_view('list_collection_files', database)
+            view_func=ListFiles.as_view('list_collection_files', self.database)
         )
 
         # data
-        self.__flask.add_url_rule(
+        app.add_url_rule(
             '/projects/<int:identifier>/data',
-            view_func=UploadFile.as_view('upload_file', database, notifications)
+            view_func=UploadFile.as_view('upload_file', self.database, notifications)
         )
-        self.__flask.add_url_rule(
+        app.add_url_rule(
             '/projects/<int:project_id>/data/<int:start>/<int:length>',
-            view_func=ListFiles.as_view('list_files', database)
+            view_func=ListFiles.as_view('list_files', self.database)
         )
-        self.__flask.add_url_rule(
+        app.add_url_rule(
             '/data/<int:identifier>/remove',
-            view_func=RemoveFile.as_view('remove_file', database, notifications)
+            view_func=RemoveFile.as_view('remove_file', self.database, notifications)
         )
-        self.__flask.add_url_rule(
+        app.add_url_rule(
             '/data/<int:file_id>',
-            view_func=GetFile.as_view('get_file', database)
+            view_func=GetFile.as_view('get_file', self.database)
         )
-        self.__flask.add_url_rule(
+        app.add_url_rule(
             '/data/<int:file_id>/<resolution>',
-            view_func=GetResizedFile.as_view('get_resized_file', database)
+            view_func=GetResizedFile.as_view('get_resized_file', self.database)
         )
-        self.__flask.add_url_rule(
+        app.add_url_rule(
             '/data/<int:file_id>/<resolution>/<crop_box>',
             view_func=GetCroppedFile.as_view('crop_result', database)
         )
-        self.__flask.add_url_rule(
+        app.add_url_rule(
             '/data/<int:file_id>/previous_next',
-            view_func=GetPreviousAndNextFile.as_view('get_previous_and_next_file', database)
+            view_func=GetPreviousAndNextFile.as_view('get_previous_and_next_file', self.database)
         )
 
         # results
-        self.__flask.add_url_rule(
+        app.add_url_rule(
             '/projects/<int:project_id>/results',
-            view_func=GetProjectResults.as_view('get_project_results', database)
+            view_func=GetProjectResults.as_view('get_project_results', self.database)
         )
-        self.__flask.add_url_rule(
+        app.add_url_rule(
             '/data/<int:file_id>/results',
-            view_func=GetResults.as_view('get_results', database)
+            view_func=GetResults.as_view('get_results', self.database)
         )
-        self.__flask.add_url_rule(
+        app.add_url_rule(
             '/data/<int:file_id>/results',
-            view_func=CreateResult.as_view('create_result', database, notifications)
+            view_func=CreateResult.as_view('create_result', self.database, notifications)
         )
-        self.__flask.add_url_rule(
+        app.add_url_rule(
             '/data/<int:file_id>/reset',
-            view_func=ResetResults.as_view('reset_results', database, notifications)
+            view_func=ResetResults.as_view('reset_results', self.database, notifications)
         )
-        self.__flask.add_url_rule(
+        app.add_url_rule(
             '/results/<int:result_id>/remove',
-            view_func=RemoveResult.as_view('remove_result', database, notifications)
+            view_func=RemoveResult.as_view('remove_result', self.database, notifications)
         )
-        self.__flask.add_url_rule(
+        app.add_url_rule(
             '/results/<int:result_id>/confirm',
-            view_func=ConfirmResult.as_view('confirm_result', database, notifications)
+            view_func=ConfirmResult.as_view('confirm_result', self.database, notifications)
         )
-        self.__flask.add_url_rule(
+        app.add_url_rule(
             '/results/<int:result_id>/label',
-            view_func=EditResultLabel.as_view('edit_result_label', database, notifications)
+            view_func=EditResultLabel.as_view('edit_result_label', self.database, notifications)
         )
-        self.__flask.add_url_rule(
+        app.add_url_rule(
             '/results/<int:result_id>/data',
-            view_func=EditResultData.as_view('edit_result_data', database, notifications)
+            view_func=EditResultData.as_view('edit_result_data', self.database, notifications)
         )
 
         # projects
-        self.__flask.add_url_rule(
+        app.add_url_rule(
             '/projects',
-            view_func=ListProjects.as_view('list_projects', database)
+            view_func=ListProjects.as_view('list_projects', self.database)
         )
-        self.__flask.add_url_rule(
+        app.add_url_rule(
             '/projects',
-            view_func=CreateProject.as_view('create_project', database, notifications, jobs)
+            view_func=CreateProject.as_view('create_project', self.database, notifications, jobs)
         )
-        self.__flask.add_url_rule(
+        app.add_url_rule(
             '/projects/<int:identifier>/label_provider',
-            view_func=ExecuteLabelProvider.as_view('execute_label_provider', database,
+            view_func=ExecuteLabelProvider.as_view('execute_label_provider', self.database,
                                                    notifications, jobs)
         )
-        self.__flask.add_url_rule(
+        app.add_url_rule(
             '/projects/<int:identifier>/external_storage',
-            view_func=ExecuteExternalStorage.as_view('execute_external_storage', database,
+            view_func=ExecuteExternalStorage.as_view('execute_external_storage', self.database,
                                                      notifications, jobs)
         )
-        self.__flask.add_url_rule(
+        app.add_url_rule(
             '/projects/<int:identifier>/remove',
-            view_func=RemoveProject.as_view('remove_project', database, notifications)
+            view_func=RemoveProject.as_view('remove_project', self.database, notifications)
         )
-        self.__flask.add_url_rule(
+        app.add_url_rule(
             '/projects/<int:identifier>/name',
-            view_func=EditProjectName.as_view('edit_project_name', database, notifications)
+            view_func=EditProjectName.as_view('edit_project_name', self.database, notifications)
         )
-        self.__flask.add_url_rule(
+        app.add_url_rule(
             '/projects/<int:identifier>/description',
-            view_func=EditProjectDescription.as_view('edit_project_description', database,
+            view_func=EditProjectDescription.as_view('edit_project_description', self.database,
                                                      notifications)
         )
 
         # pipelines
-        self.__flask.add_url_rule(
+        app.add_url_rule(
             '/projects/<int:project_id>/pipelines/fit',
-            view_func=FitModel.as_view('fit_model', database, jobs, pipelines)
+            view_func=FitModel.as_view('fit_model', self.database, jobs, pipelines)
         )
-        self.__flask.add_url_rule(
+        app.add_url_rule(
             '/projects/<int:project_id>/pipelines/predict',
-            view_func=PredictModel.as_view('predict_model', database, notifications, jobs,
+            view_func=PredictModel.as_view('predict_model', self.database, notifications, jobs,
                                            pipelines)
         )
-        self.__flask.add_url_rule(
+        app.add_url_rule(
             '/data/<int:file_id>/predict',
-            view_func=PredictFile.as_view('predict_file', database, notifications, jobs, pipelines)
+            view_func=PredictFile.as_view('predict_file', self.database, notifications, jobs, pipelines)
         )
 
+    def run(self):
         # finally start web server
-        eventlet.wsgi.server(eventlet.listen((settings['host'], settings['port'])), self.__app)
+        eventlet.wsgi.server(eventlet.listen((self.host, self.port)), app)

+ 4 - 1
requirements.txt

@@ -4,7 +4,10 @@ Pillow
 scipy
 eventlet
 flask
-python-socketio
+flask-socketio
+flask-sqlalchemy
+flask-migrate
+# python-socketio
 munch
 scikit-image
 

+ 2 - 1
settings.json

@@ -1,5 +1,6 @@
 {
   "host": "",
   "port": 5000,
-  "allowedOrigins": []
+  "allowedOrigins": [],
+  "database": "data2.sqlite3"
 }