6
0
Эх сурвалжийг харах

reworked DB models to SQLAlchemy

Dimitri Korsch 3 жил өмнө
parent
commit
1ab5050a2f

+ 2 - 7
app.py

@@ -1,19 +1,14 @@
 #!/usr/bin/env python
 import os
-import json
 
+from pycs import app, settings
 from pycs.frontend.WebServer import WebServer
 
-print('- Loading settings')
-with open('settings.json') as file:
-    settings = json.load(file)
-
 # create projects folder
 if not os.path.exists('projects/'):
     os.mkdir('projects/')
 
 # start web server
-server = WebServer(settings)
-
 if __name__ == '__main__':
+    server = WebServer(app, settings)
     server.run()

+ 1 - 0
migrations/README

@@ -0,0 +1 @@
+Generic single-database configuration.

+ 50 - 0
migrations/alembic.ini

@@ -0,0 +1,50 @@
+# A generic, single database configuration.
+
+[alembic]
+# template used to generate migration files
+# file_template = %%(rev)s_%%(slug)s
+
+# set to 'true' to run the environment during
+# the 'revision' command, regardless of autogenerate
+# revision_environment = false
+
+
+# Logging configuration
+[loggers]
+keys = root,sqlalchemy,alembic,flask_migrate
+
+[handlers]
+keys = console
+
+[formatters]
+keys = generic
+
+[logger_root]
+level = WARN
+handlers = console
+qualname =
+
+[logger_sqlalchemy]
+level = WARN
+handlers =
+qualname = sqlalchemy.engine
+
+[logger_alembic]
+level = INFO
+handlers =
+qualname = alembic
+
+[logger_flask_migrate]
+level = INFO
+handlers =
+qualname = flask_migrate
+
+[handler_console]
+class = StreamHandler
+args = (sys.stderr,)
+level = NOTSET
+formatter = generic
+
+[formatter_generic]
+format = %(levelname)-5.5s [%(name)s] %(message)s
+datefmt = %H:%M:%S

+ 91 - 0
migrations/env.py

@@ -0,0 +1,91 @@
+from __future__ import with_statement
+
+import logging
+from logging.config import fileConfig
+
+from flask import current_app
+
+from alembic import context
+
+# this is the Alembic Config object, which provides
+# access to the values within the .ini file in use.
+config = context.config
+
+# Interpret the config file for Python logging.
+# This line sets up loggers basically.
+fileConfig(config.config_file_name)
+logger = logging.getLogger('alembic.env')
+
+# add your model's MetaData object here
+# for 'autogenerate' support
+# from myapp import mymodel
+# target_metadata = mymodel.Base.metadata
+config.set_main_option(
+    'sqlalchemy.url',
+    str(current_app.extensions['migrate'].db.get_engine().url).replace(
+        '%', '%%'))
+target_metadata = current_app.extensions['migrate'].db.metadata
+
+# other values from the config, defined by the needs of env.py,
+# can be acquired:
+# my_important_option = config.get_main_option("my_important_option")
+# ... etc.
+
+
+def run_migrations_offline():
+    """Run migrations in 'offline' mode.
+
+    This configures the context with just a URL
+    and not an Engine, though an Engine is acceptable
+    here as well.  By skipping the Engine creation
+    we don't even need a DBAPI to be available.
+
+    Calls to context.execute() here emit the given string to the
+    script output.
+
+    """
+    url = config.get_main_option("sqlalchemy.url")
+    context.configure(
+        url=url, target_metadata=target_metadata, literal_binds=True
+    )
+
+    with context.begin_transaction():
+        context.run_migrations()
+
+
+def run_migrations_online():
+    """Run migrations in 'online' mode.
+
+    In this scenario we need to create an Engine
+    and associate a connection with the context.
+
+    """
+
+    # this callback is used to prevent an auto-migration from being generated
+    # when there are no changes to the schema
+    # reference: http://alembic.zzzcomputing.com/en/latest/cookbook.html
+    def process_revision_directives(context, revision, directives):
+        if getattr(config.cmd_opts, 'autogenerate', False):
+            script = directives[0]
+            if script.upgrade_ops.is_empty():
+                directives[:] = []
+                logger.info('No changes in schema detected.')
+
+    connectable = current_app.extensions['migrate'].db.get_engine()
+
+    with connectable.connect() as connection:
+        context.configure(
+            connection=connection,
+            target_metadata=target_metadata,
+            process_revision_directives=process_revision_directives,
+            **current_app.extensions['migrate'].configure_args
+        )
+
+        with context.begin_transaction():
+            context.run_migrations()
+
+
+if context.is_offline_mode():
+    run_migrations_offline()
+else:
+    run_migrations_online()

+ 24 - 0
migrations/script.py.mako

@@ -0,0 +1,24 @@
+"""${message}
+
+Revision ID: ${up_revision}
+Revises: ${down_revision | comma,n}
+Create Date: ${create_date}
+
+"""
+from alembic import op
+import sqlalchemy as sa
+${imports if imports else ""}
+
+# revision identifiers, used by Alembic.
+revision = ${repr(up_revision)}
+down_revision = ${repr(down_revision)}
+branch_labels = ${repr(branch_labels)}
+depends_on = ${repr(depends_on)}
+
+
+def upgrade():
+    ${upgrades if upgrades else "pass"}
+
+
+def downgrade():
+    ${downgrades if downgrades else "pass"}

+ 124 - 0
migrations/versions/ece44a2b50ef_.py

@@ -0,0 +1,124 @@
+"""empty message
+
+Revision ID: ece44a2b50ef
+Revises: 
+Create Date: 2021-07-21 11:25:45.628242
+
+"""
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision = 'ece44a2b50ef'
+down_revision = None
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.create_table('label_provider',
+    sa.Column('id', sa.Integer(), nullable=False),
+    sa.Column('name', sa.String(), nullable=False),
+    sa.Column('description', sa.String(), nullable=True),
+    sa.Column('root_folder', sa.String(), nullable=False),
+    sa.PrimaryKeyConstraint('id'),
+    sa.UniqueConstraint('root_folder')
+    )
+    op.create_table('model',
+    sa.Column('id', sa.Integer(), nullable=False),
+    sa.Column('name', sa.String(), nullable=False),
+    sa.Column('description', sa.String(), nullable=True),
+    sa.Column('root_folder', sa.String(), nullable=False),
+    sa.Column('supports_encoded', sa.String(), nullable=False),
+    sa.PrimaryKeyConstraint('id'),
+    sa.UniqueConstraint('root_folder')
+    )
+    op.create_table('project',
+    sa.Column('id', sa.Integer(), nullable=False),
+    sa.Column('name', sa.String(), nullable=False),
+    sa.Column('description', sa.String(), nullable=True),
+    sa.Column('created', sa.DateTime(), nullable=False),
+    sa.Column('model_id', sa.Integer(), nullable=True),
+    sa.Column('label_provider_id', sa.Integer(), nullable=True),
+    sa.Column('root_folder', sa.String(), nullable=False),
+    sa.Column('external_data', sa.Boolean(), nullable=False),
+    sa.Column('data_folder', sa.String(), nullable=False),
+    sa.ForeignKeyConstraint(['label_provider_id'], ['label_provider.id'], ondelete='SET NULL'),
+    sa.ForeignKeyConstraint(['model_id'], ['model.id'], ondelete='SET NULL'),
+    sa.PrimaryKeyConstraint('id'),
+    sa.UniqueConstraint('root_folder')
+    )
+    op.create_index(op.f('ix_project_created'), 'project', ['created'], unique=False)
+    op.create_table('collection',
+    sa.Column('id', sa.Integer(), nullable=False),
+    sa.Column('name', sa.String(), nullable=False),
+    sa.Column('project_id', sa.Integer(), nullable=False),
+    sa.Column('reference', sa.String(), nullable=False),
+    sa.Column('description', sa.String(), nullable=True),
+    sa.Column('position', sa.Integer(), nullable=False),
+    sa.Column('autoselect', sa.Boolean(), nullable=False),
+    sa.ForeignKeyConstraint(['project_id'], ['project.id'], ondelete='CASCADE'),
+    sa.PrimaryKeyConstraint('id'),
+    sa.UniqueConstraint('project_id', 'reference')
+    )
+    op.create_table('label',
+    sa.Column('id', sa.Integer(), nullable=False),
+    sa.Column('name', sa.String(), nullable=False),
+    sa.Column('project_id', sa.Integer(), nullable=False),
+    sa.Column('parent_id', sa.Integer(), nullable=True),
+    sa.Column('created', sa.DateTime(), nullable=False),
+    sa.Column('reference', sa.String(), nullable=True),
+    sa.ForeignKeyConstraint(['parent_id'], ['label.id'], ondelete='SET NULL'),
+    sa.ForeignKeyConstraint(['project_id'], ['project.id'], ondelete='CASCADE'),
+    sa.PrimaryKeyConstraint('id'),
+    sa.UniqueConstraint('project_id', 'reference')
+    )
+    op.create_index(op.f('ix_label_created'), 'label', ['created'], unique=False)
+    op.create_table('file',
+    sa.Column('id', sa.Integer(), nullable=False),
+    sa.Column('name', sa.String(), nullable=False),
+    sa.Column('uuid', sa.String(), nullable=False),
+    sa.Column('extension', sa.String(), nullable=False),
+    sa.Column('type', sa.String(), nullable=False),
+    sa.Column('size', sa.String(), nullable=False),
+    sa.Column('created', sa.DateTime(), nullable=False),
+    sa.Column('path', sa.String(), nullable=False),
+    sa.Column('frames', sa.Integer(), nullable=True),
+    sa.Column('fps', sa.Float(), nullable=True),
+    sa.Column('project_id', sa.Integer(), nullable=False),
+    sa.Column('collection_id', sa.Integer(), nullable=True),
+    sa.ForeignKeyConstraint(['collection_id'], ['collection.id'], ondelete='SET NULL'),
+    sa.ForeignKeyConstraint(['project_id'], ['project.id'], ondelete='CASCADE'),
+    sa.PrimaryKeyConstraint('id'),
+    sa.UniqueConstraint('project_id', 'path')
+    )
+    op.create_index(op.f('ix_file_created'), 'file', ['created'], unique=False)
+    op.create_table('result',
+    sa.Column('id', sa.Integer(), nullable=False),
+    sa.Column('file_id', sa.Integer(), nullable=False),
+    sa.Column('origin', sa.String(), nullable=False),
+    sa.Column('type', sa.String(), nullable=False),
+    sa.Column('label_id', sa.Integer(), nullable=True),
+    sa.Column('data', sa.String(), nullable=True),
+    sa.ForeignKeyConstraint(['file_id'], ['file.id'], ondelete='CASCADE'),
+    sa.ForeignKeyConstraint(['label_id'], ['label.id'], ondelete='SET NULL'),
+    sa.PrimaryKeyConstraint('id')
+    )
+    # ### end Alembic commands ###
+
+
+def downgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.drop_table('result')
+    op.drop_index(op.f('ix_file_created'), table_name='file')
+    op.drop_table('file')
+    op.drop_index(op.f('ix_label_created'), table_name='label')
+    op.drop_table('label')
+    op.drop_table('collection')
+    op.drop_index(op.f('ix_project_created'), table_name='project')
+    op.drop_table('project')
+    op.drop_table('model')
+    op.drop_table('label_provider')
+    # ### end Alembic commands ###

+ 17 - 0
pycs/__init__.py

@@ -0,0 +1,17 @@
+import json
+
+from flask import Flask
+from flask_migrate import Migrate
+from flask_sqlalchemy import SQLAlchemy
+
+
+print('- Loading settings')
+with open('settings.json') as file:
+    settings = json.load(file)
+
+
+app = Flask(__name__)
+app.config["SQLALCHEMY_DATABASE_URI"] = f"sqlite:///{settings['database']}"
+app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False
+db = SQLAlchemy(app)
+migrate = Migrate(app, db)

+ 26 - 53
pycs/database/Collection.py

@@ -1,70 +1,43 @@
 from contextlib import closing
 from typing import Iterator
 
+from pycs import db
 from pycs.database.File import File
+from pycs.database.base import NamedBaseModel
 
+class Collection(NamedBaseModel):
 
-class Collection:
-    """
-    database class for collections
-    """
+    # table columns
+    project_id = db.Column(
+        db.Integer, db.ForeignKey("project.id", ondelete="CASCADE"), nullable=False)
 
-    def __init__(self, database, row):
-        self.database = database
+    reference = db.Column(
+        db.String, nullable=False)
 
-        self.identifier = row[0]
-        self.project_id = row[1]
-        self.reference = row[2]
-        self.name = row[3]
-        self.description = row[4]
-        self.position = row[5]
-        self.autoselect = False if row[6] == 0 else True
+    description = db.Column(
+        db.String)
 
-    def set_name(self, name: str):
-        """
-        set this collection's name
+    position = db.Column(
+        db.Integer, nullable=False)
 
-        :param name: new name
-        :return:
-        """
-        with closing(self.database.con.cursor()) as cursor:
-            cursor.execute('UPDATE collections SET name = ? WHERE id = ?', (name, self.identifier))
-            self.name = name
+    autoselect = db.Column(
+        db.Boolean, nullable=False)
 
-    def remove(self) -> None:
-        """
-        remove this collection from the database
 
-        :return:
-        """
-        with closing(self.database.con.cursor()) as cursor:
-            cursor.execute('DELETE FROM collections WHERE id = ?', [self.identifier])
+    # contraints
+    __table_args__ = (
+        db.UniqueConstraint('project_id', 'reference'),
+    )
 
-    def count_files(self) -> int:
-        """
-        count files associated with this project
+    # relationships to other models
+    files = db.relationship("File", backref="collection", lazy=True)
 
-        :return: count
-        """
-        with closing(self.database.con.cursor()) as cursor:
-            cursor.execute('SELECT COUNT(*) FROM files WHERE project = ? AND collection = ?',
-                           (self.project_id, self.identifier))
-            return cursor.fetchone()[0]
 
-    def files(self, offset: int = 0, limit: int = -1) -> Iterator[File]:
-        """
-        get an iterator of files associated with this collection
+    def count_files(self) -> int:
+        return self.files.count()
 
-        :param offset: file offset
-        :param limit: file limit
-        :return: iterator of files
-        """
-        with closing(self.database.con.cursor()) as cursor:
-            cursor.execute('''
-                SELECT * FROM files
-                WHERE project = ? AND collection = ?
-                ORDER BY id ASC LIMIT ? OFFSET ?
-                ''', (self.project_id, self.identifier, limit, offset))
+    def files_it(self, offset: int = 0, limit: int = -1) -> Iterator[File]:
+        # self.files.filter
+        files = File.query.filter_by(project_id=self.project_id, collection_id=self.id)
+        raise NotImplementedError
 
-            for row in cursor:
-                yield File(self.database, row)

+ 94 - 182
pycs/database/File.py

@@ -1,240 +1,152 @@
+from __future__ import annotations
+
+import json
+import typing as T
+
 from contextlib import closing
-from json import dumps
-from typing import List, Optional
+from datetime import datetime
 
+from pycs import db
 from pycs.database.Result import Result
+from pycs.database.Project import Project
+from pycs.database.Collection import Collection
+from pycs.database.base import NamedBaseModel
 
 
-class File:
-    """
-    database class for files
-    """
-
-    def __init__(self, database, row):
-        self.database = database
-
-        self.identifier = row[0]
-        self.uuid = row[1]
-        self.project_id = row[2]
-        self.collection_id = row[3]
-        self.type = row[4]
-        self.name = row[5]
-        self.extension = row[6]
-        self.size = row[7]
-        self.created = row[8]
-        self.path = row[9]
-        self.frames = row[10]
-        self.fps = row[11]
-
-    def project(self):
-        """
-        get the project associated with this file
+class File(NamedBaseModel):
 
-        :return: project
-        """
-        return self.database.project(self.project_id)
+    # table columns
+    uuid = db.Column(db.String, nullable=False)
 
-    def collection(self):
-        """
-        get the collection associated with this file
+    extension = db.Column(db.String, nullable=False)
 
-        :return: collection
-        """
-        if self.collection_id is None:
-            return None
+    type = db.Column(db.String, nullable=False)
+
+    size = db.Column(db.String, nullable=False)
+
+    created = db.Column(db.DateTime, default=datetime.utcnow,
+        index=True, nullable=False)
+
+    path = db.Column(db.String, nullable=False)
 
-        return self.database.collection(self.collection_id)
+    frames = db.Column(db.Integer)
 
-    def set_collection(self, collection_id: Optional[int]):
+    fps = db.Column(db.Float)
+
+    project_id = db.Column(
+        db.Integer,
+        db.ForeignKey("project.id", ondelete="CASCADE"),
+        nullable=False)
+
+    collection_id = db.Column(
+        db.Integer,
+        db.ForeignKey("collection.id", ondelete="SET NULL"))
+
+    # contraints
+    __table_args__ = (
+        db.UniqueConstraint('project_id', 'path'),
+    )
+
+
+    # relationships to other models
+    results = db.relationship("Result", backref="file", lazy=True)
+
+
+    def set_collection(self, id: T.Optional[int]):
         """
         set this file's collection
 
-        :param collection_id: new collection
+        :param id: new collection id
         :return:
         """
-        with closing(self.database.con.cursor()) as cursor:
-            cursor.execute('UPDATE files SET collection = ? WHERE id = ?',
-                           (collection_id, self.identifier))
-            self.collection_id = collection_id
 
-    def set_collection_by_reference(self, collection_reference: Optional[str]):
+        self.collection_id = id
+        self.commit()
+
+    def set_collection_by_reference(self, collection_reference: T.Optional[str]):
         """
         set this file's collection
 
         :param collection_reference: collection reference
         :return:
         """
-        if collection_reference is None:
+        if self.collection_reference is None:
             self.set_collection(None)
-            return
 
-        with closing(self.database.con.cursor()) as cursor:
-            cursor.execute('SELECT id FROM collections WHERE reference = ?', [collection_reference])
-            row = cursor.fetchone()
+        collection = Collection.query.filter_by(reference=collection_reference).one()
+        self.collection = collection
+        self.commit()
 
-        self.set_collection(row[0] if row is not None else None)
-
-    def remove(self) -> None:
+    def _get_another_file(self, *query) -> T.Optional[File]:
         """
-        remove this file from the database
+        get the first file matching the query ordered by descending id
 
-        :return:
+        :return: another file or None
         """
-        with closing(self.database.con.cursor()) as cursor:
-            cursor.execute('DELETE FROM files WHERE id = ?', [self.identifier])
+        return File.query.filter(*query)\
+            .order_by(File.id.desc())\
+            .first()
 
-    def previous(self):
+    def next(self) -> T.Optional[File]:
         """
-        get the predecessor of this file
+        get the successor of this file
 
-        :return: another file
+        :return: another file or None
         """
-        with closing(self.database.con.cursor()) as cursor:
-            cursor.execute('''
-                SELECT * FROM files WHERE id < ? AND project = ? ORDER BY id DESC LIMIT 1
-            ''', (self.identifier, self.project_id))
-            row = cursor.fetchone()
-
-            if row is not None:
-                return File(self.database, row)
+        query = File.id > self.id, Project.id == self.project_id
+        return self._get_another_file(*query)
 
-            return None
 
-    def next(self):
+    def previous(self) -> T.Optional[File]:
         """
-        get the successor of this file
+        get the predecessor of this file
 
-        :return: another file
+        :return: another file or None
         """
-        with closing(self.database.con.cursor()) as cursor:
-            cursor.execute('''
-                    SELECT * FROM files WHERE id > ? AND project = ? ORDER BY id ASC LIMIT 1
-                ''', (self.identifier, self.project_id))
-            row = cursor.fetchone()
+        query = File.id < self.id, Project.id == self.project_id
+        return self._get_another_file(*query)
 
-            if row is not None:
-                return File(self.database, row)
 
-            return None
-
-    def previous_in_collection(self):
+    def next_in_collection(self) -> T.Optional[File]:
         """
         get the predecessor of this file
 
-        :return: another file
-        """
-        with closing(self.database.con.cursor()) as cursor:
-            if self.collection_id is None:
-                cursor.execute('''
-                    SELECT * FROM files
-                    WHERE id < ? AND project = ? AND collection IS NULL
-                    ORDER BY id DESC
-                    LIMIT 1
-                ''', (self.identifier, self.project_id))
-            else:
-                cursor.execute('''
-                    SELECT * FROM files
-                    WHERE id < ? AND project = ? AND collection = ?
-                    ORDER BY id DESC
-                    LIMIT 1
-                ''', (self.identifier, self.project_id, self.collection_id))
-
-            row = cursor.fetchone()
-            if row is not None:
-                return File(self.database, row)
-
-            return None
-
-    def next_in_collection(self):
+        :return: another file or None
         """
-        get the successor of this file
+        query = File.id > self.id, Project.id == self.project_id, Collection.id == self.collection_id
+        return self._get_another_file(*query)
 
-        :return: another file
-        """
-        with closing(self.database.con.cursor()) as cursor:
-            if self.collection_id is None:
-                cursor.execute('''
-                    SELECT * FROM files
-                    WHERE id > ? AND project = ? AND collection IS NULL
-                    ORDER BY id ASC
-                    LIMIT 1
-                ''', (self.identifier, self.project_id))
-            else:
-                cursor.execute(''' 
-                    SELECT * FROM files
-                    WHERE id > ? AND project = ? AND collection = ?
-                    ORDER BY id ASC
-                    LIMIT 1
-                ''', (self.identifier, self.project_id, self.collection_id))
-
-            row = cursor.fetchone()
-            if row is not None:
-                return File(self.database, row)
-
-            return None
-
-    def results(self) -> List[Result]:
-        """
-        get a list of all results associated with this file
 
-        :return: list of results
-        """
-        with closing(self.database.con.cursor()) as cursor:
-            cursor.execute('SELECT * FROM results WHERE file = ?', [self.identifier])
-            return list(map(
-                lambda row: Result(self.database, row),
-                cursor.fetchall()
-            ))
-
-    def result(self, identifier) -> Optional[Result]:
+    def previous_in_collection(self) -> T.Optional[File]:
         """
-        get a specific result using its unique identifier
+        get the predecessor of this file
 
-        :param identifier: unique identifier
-        :return: result
+        :return: another file or None
         """
-        with closing(self.database.con.cursor()) as cursor:
-            cursor.execute('''
-                SELECT * FROM results WHERE id = ? AND file = ?
-            ''', (identifier, self.identifier))
-            row = cursor.fetchone()
+        query = File.id < self.id, Project.id == self.project_id, Collection.id == self.collection_id
+        return self._get_another_file(*query)
 
-            if row is not None:
-                return Result(self.database, row)
 
-            return None
+    def result(self, id: int) -> T.Optional[Result]:
+        return self.results.get(id)
 
-    def create_result(self, origin, result_type, label, data=None):
-        """
-        create a result
 
-        :param origin:
-        :param result_type:
-        :param label:
-        :param data:
-        :return:
-        """
-        if data is not None:
-            data = dumps(data)
+    def create_result(self, origin, result_type, label, data: T.Optional[dict] = None):
+        data = data if data is None else json.dumps(data)
 
-        with closing(self.database.con.cursor()) as cursor:
-            cursor.execute('''
-                INSERT INTO results (file, origin, type, label, data)
-                VALUES              (   ?,      ?,    ?,     ?,    ?)
-            ''', (self.identifier, origin, result_type, label, data))
+        result = Result.new(commit=True,
+                            file=self,
+                            origin=origin,
+                            type=result_type,
+                            label=label,
+                            data=data)
+        return result
 
-            return self.result(cursor.lastrowid)
 
-    def remove_results(self, origin='pipeline') -> List[Result]:
-        with closing(self.database.con.cursor()) as cursor:
-            cursor.execute('''
-                SELECT * FROM results WHERE file = ? AND origin = ?
-            ''', (self.identifier, origin))
+    def remove_results(self, origin='pipeline'):
 
-            results = list(map(lambda row: Result(self.database, row), cursor.fetchall()))
+        results = Result.query.filter(Result.file == self, Result.origin == origin)
 
-            cursor.execute('''
-                DELETE FROM results WHERE file = ? AND origin = ?
-            ''', (self.identifier, origin))
+        results.remove()
 
-            return results
+        return results

+ 38 - 85
pycs/database/Label.py

@@ -1,106 +1,59 @@
+from __future__ import annotations
 from contextlib import closing
+from datetime import datetime
 
+from pycs import db
+from pycs.database.base import NamedBaseModel
 
-class Label:
-    """
-    database class for labels
-    """
+def compare_children(start_label: Label, id: int):
+    """ check for cyclic relationships """
 
-    def __init__(self, database, row):
-        self.database = database
+    labels_to_check = [start_label]
 
-        self.identifier = row[0]
-        self.project_id = row[1]
-        self.parent_id = row[2]
-        self.created = row[3]
-        self.reference = row[4]
-        self.name = row[5]
+    while labels_to_check:
+        label = labels_to_check.pop(0)
 
-    def project(self):
-        """
-        get the project this label is associated with
+        if label.id == id:
+            return False
 
-        :return: project
-        """
-        return self.database.project(self.project_id)
+        labels_to_check.extend(label.children)
 
-    def set_name(self, name: str):
-        """
-        set this labels name
+    return True
 
-        :param name: new name
-        :return:
-        """
-        with closing(self.database.con.cursor()) as cursor:
-            cursor.execute('UPDATE labels SET name = ? WHERE id = ?', (name, self.identifier))
-            self.name = name
-
-    def set_parent(self, parent_id: int):
-        """
-        set this labels parent
+class Label(NamedBaseModel):
 
-        :param parent_id: parent's id
-        :return:
-        """
+    id = db.Column(db.Integer, primary_key=True)
+    project_id = db.Column(
+        db.Integer,
+        db.ForeignKey("project.id", ondelete="CASCADE"),
+        nullable=False)
 
-        # check for cyclic relationships
-        def compare_children(label, identifier):
-            if label.identifier == identifier:
-                return False
+    parent_id = db.Column(
+        db.Integer,
+        db.ForeignKey("label.id", ondelete="SET NULL"))
 
-            for child in label.children():
-                if not compare_children(child, identifier):
-                    return False
+    created = db.Column(db.DateTime, default=datetime.utcnow,
+        index=True, nullable=False)
 
-            return True
+    reference = db.Column(db.String)
 
-        if not compare_children(self, parent_id):
-            raise ValueError('parent_id')
+    # contraints
+    __table_args__ = (
+        db.UniqueConstraint('project_id', 'reference'),
+    )
 
-        # insert parent id
-        with closing(self.database.con.cursor()) as cursor:
-            cursor.execute('UPDATE labels SET parent = ? WHERE id = ?',
-                           (parent_id, self.identifier))
-            self.parent_id = parent_id
+    # relationships to other models
+    parent = db.relationship("Label", backref="children", remote_side=[id])
 
-    def remove(self):
+    def set_parent(self, parent_id: int):
         """
-        remove this label from the database
+        set this labels parent
 
+        :param parent_id: parent's id
         :return:
         """
-        with closing(self.database.con.cursor()) as cursor:
-            cursor.execute('DELETE FROM labels WHERE id = ?', [self.identifier])
-
-    def parent(self):
-        """
-        get this labels parent from the database
-
-        :return: parent or None
-        """
-        if self.parent_id is None:
-            return None
-
-        with closing(self.database.con.cursor()) as cursor:
-            cursor.execute('SELECT * FROM labels WHERE id = ? AND project = ?',
-                           (self.parent_id, self.project_id))
-            row = cursor.fetchone()
-
-            if row is not None:
-                return Label(self.database, row)
-
-        return None
-
-    def children(self):
-        """
-        get this labels children from the database
+        if not compare_children(self, parent_id):
+            raise ValueError('Cyclic relationship detected!')
 
-        :return: list of children
-        """
-        with closing(self.database.con.cursor()) as cursor:
-            cursor.execute('SELECT * FROM labels WHERE parent = ? AND project = ?',
-                           (self.identifier, self.project_id))
-            return list(map(
-                lambda row: Label(self.database, row),
-                cursor.fetchall()
-            ))
+        self.parent_id = parent_id
+        self.commit()

+ 8 - 7
pycs/database/LabelProvider.py

@@ -1,12 +1,13 @@
-class LabelProvider:
+from pycs import db
+from pycs.database.base import NamedBaseModel
+
+class LabelProvider(NamedBaseModel):
     """
     database class for label providers
     """
 
-    def __init__(self, database, row):
-        self.database = database
+    description = db.Column(db.String)
+    root_folder = db.Column(db.String, nullable=False, unique=True)
 
-        self.identifier = row[0]
-        self.name = row[1]
-        self.description = row[2]
-        self.root_folder = row[3]
+    # relationships to other models
+    projects = db.relationship("Project", backref="label_provider", lazy=True)

+ 33 - 51
pycs/database/Model.py

@@ -1,56 +1,38 @@
-from contextlib import closing
-from json import loads, dumps
+import json
 
+from pycs import db
+from pycs.database.base import NamedBaseModel
 
-class Model:
+class Model(NamedBaseModel):
     """
-    database class for label providers
+    database class for ML Models
     """
 
-    def __init__(self, database, row):
-        self.database = database
-
-        self.identifier = row[0]
-        self.name = row[1]
-        self.description = row[2]
-        self.root_folder = row[3]
-        self.supports = loads(row[4])
-
-    def copy_to(self, name: str, root_folder: str):
-        """
-        copies the models database entry while changing name and root_folder
-
-        :param name: copy name
-        :param root_folder: copy root folder
-        :return: copy
-        """
-        supports = dumps(self.supports)
-
-        with closing(self.database.con.cursor()) as cursor:
-            cursor.execute('''
-                INSERT INTO models (name, description, root_folder, supports)
-                VALUES (?, ?, ?, ?)
-                ON CONFLICT (root_folder)
-                DO UPDATE SET name = ?, description = ?, supports = ?
-            ''', (name, self.description, root_folder, supports, name, self.description, supports))
-
-            # lastrowid is 0 if on conflict clause applies.
-            # If this is the case we do an extra query to receive the row id.
-            if cursor.lastrowid > 0:
-                row_id = cursor.lastrowid
-                insert = True
-            else:
-                cursor.execute('SELECT id FROM models WHERE root_folder = ?', [root_folder])
-                row_id = cursor.fetchone()[0]
-                insert = False
-
-        return self.database.model(row_id), insert
-
-    def remove(self):
-        """
-        remove this model from the database
-
-        :return:
-        """
-        with closing(self.database.con.cursor()) as cursor:
-            cursor.execute('DELETE FROM models WHERE id = ?', [self.identifier])
+    description = db.Column(db.String)
+    root_folder = db.Column(db.String, nullable=False, unique=True)
+    supports_encoded = db.Column(db.String, nullable=False)
+
+    # relationships to other models
+    projects = db.relationship("Project", backref="model", lazy=True)
+
+    @property
+    def supports(self):
+        return json.loads(self.supports_encoded)
+
+
+    def copy_to(self, new_name: str, new_root_folder: str):
+
+        model = Model.query.get(root_folder=new_root_folder)
+        is_new = False
+
+        if model is None:
+            model = Model.new(root_folder=new_root_folder)
+            is_new = True
+
+        model.name = name
+        model.description = self.description
+        model.supports_encoded = self.supports_encoded
+
+        self.commit()
+
+        return model, is_new

+ 121 - 249
pycs/database/Project.py

@@ -1,164 +1,108 @@
+import typing as T
+
 from contextlib import closing
+from datetime import datetime
 from os.path import join
-from time import time
 from typing import List, Optional, Tuple, Iterator
 
+from pycs import db
+from pycs.database.base import NamedBaseModel
+
 from pycs.database.Collection import Collection
 from pycs.database.File import File
 from pycs.database.Label import Label
-from pycs.database.LabelProvider import LabelProvider
-from pycs.database.Model import Model
 
 
-class Project:
-    """
-    database class for projects
-    """
+class Project(NamedBaseModel):
+    description = db.Column(db.String)
 
-    def __init__(self, database, row):
-        self.database = database
+    created = db.Column(db.DateTime, default=datetime.utcnow,
+        index=True, nullable=False)
 
-        self.identifier = row[0]
-        self.name = row[1]
-        self.description = row[2]
-        self.created = row[3]
-        self.model_id = row[4]
-        self.label_provider_id = row[5]
-        self.root_folder = row[6]
-        self.external_data = bool(row[7])
-        self.data_folder = row[8]
+    model_id = db.Column(
+        db.Integer,
+        db.ForeignKey("model.id", ondelete="SET NULL"))
 
-    def model(self) -> Model:
-        """
-        get the model this project is associated with
+    label_provider_id = db.Column(
+        db.Integer,
+        db.ForeignKey("label_provider.id", ondelete="SET NULL"))
 
-        :return: model
-        """
-        return self.database.model(self.model_id)
+    root_folder = db.Column(db.String, nullable=False, unique=True)
 
-    def label_provider(self) -> Optional[LabelProvider]:
-        """
-        get the label provider this project is associated with
+    external_data = db.Column(db.Boolean, nullable=False)
 
-        :return: label provider
-        """
-        if self.label_provider_id is not None:
-            return self.database.label_provider(self.label_provider_id)
+    data_folder = db.Column(db.String, nullable=False)
 
-        return None
+    # contraints
+    __table_args__ = ()
 
-    def labels(self) -> List[Label]:
-        """
-        get a list of labels associated with this project
+    # relationships to other models
+    files = db.relationship("File", backref="project", lazy=True)
+    labels = db.relationship("Label", backref="project", lazy=True)
+    collections = db.relationship("Collection", backref="project", lazy=True)
 
-        :return: list of labels
-        """
-        with closing(self.database.con.cursor()) as cursor:
-            cursor.execute('SELECT * FROM labels WHERE project = ?', [self.identifier])
-            return list(map(
-                lambda row: Label(self.database, row),
-                cursor.fetchall()
-            ))
 
-    def label(self, identifier: int) -> Optional[Label]:
+    def label(self, id: int) -> T.Optional[Label]:
         """
         get a label using its unique identifier
 
         :param identifier: unique identifier
         :return: label
         """
-        with closing(self.database.con.cursor()) as cursor:
-            cursor.execute('SELECT * FROM labels WHERE id = ? AND project = ?',
-                           (identifier, self.identifier))
-            row = cursor.fetchone()
-
-            if row is not None:
-                return Label(self.database, row)
-
-            return None
+        return self.labels.get(id)
 
-    def create_label(self, name: str, reference: str = None,
-                     parent_id: int = None) -> Tuple[Optional[Label], bool]:
+    def file(self, id: int) -> T.Optional[Label]:
         """
-        create a label for this project. If there is already a label with the same reference
-        in the database its name is updated.
+        get a file using its unique identifier
 
-        :param name: label name
-        :param reference: label reference
-        :param parent_id: parent's identifier
-        :return: created or edited label, insert
+        :param identifier: unique identifier
+        :return: file
         """
-        created = int(time())
+        return self.files.get(id)
 
-        with closing(self.database.con.cursor()) as cursor:
-            cursor.execute('''
-                INSERT INTO labels (project, parent, created, reference, name)
-                VALUES (?, ?, ?, ?, ?)
-                ON CONFLICT (project, reference) DO
-                UPDATE SET parent = ?, name = ?
-            ''', (self.identifier, parent_id, created, reference, name, parent_id, name))
-
-            # lastrowid is 0 if on conflict clause applies.
-            # If this is the case we do an extra query to receive the row id.
-            if cursor.lastrowid > 0:
-                row_id = cursor.lastrowid
-                insert = True
-            else:
-                cursor.execute('SELECT id FROM labels WHERE project = ? AND reference = ?',
-                               (self.identifier, reference))
-                row_id = cursor.fetchone()[0]
-                insert = False
-
-        return self.label(row_id), insert
-
-    def collections(self) -> List[Collection]:
+    def collection(self, id: int) -> T.Optional[Collection]:
         """
-        get a list of collections associated with this project
+        get a collection using its unique identifier
 
-        :return: list of collections
+        :param identifier: unique identifier
+        :return: collection
         """
-        with closing(self.database.con.cursor()) as cursor:
-            cursor.execute('SELECT * FROM collections WHERE project = ? ORDER BY position ASC',
-                           [self.identifier])
-
-            return list(map(
-                lambda row: Collection(self.database, row),
-                cursor.fetchall()
-            ))
+        return self.collections.get(id)
 
-    def collection(self, identifier: int) -> Optional[Collection]:
+    def collection_by_reference(self, reference: str) -> T.Optional[Collection]:
         """
         get a collection using its unique identifier
 
         :param identifier: unique identifier
         :return: collection
         """
-        with closing(self.database.con.cursor()) as cursor:
-            cursor.execute('SELECT * FROM collections WHERE id = ? AND project = ?',
-                           (identifier, self.identifier))
-            row = cursor.fetchone()
+        return self.collections.filter_by(reference=reference).one()
 
-            if row is not None:
-                return Collection(self.database, row)
-
-            return None
-
-    def collection_by_reference(self, reference: str):
+    def create_label(self, name: str, reference: str = None,
+                     parent_id: int = None) -> Tuple[Optional[Label], bool]:
         """
-        get a collection using its reference string
+        create a label for this project. If there is already a label with the same reference
+        in the database its name is updated.
 
-        :param reference: reference string
-        :return: collection
+        :param name: label name
+        :param reference: label reference
+        :param parent_id: parent's identifier
+        :return: created or edited label, insert
         """
-        with closing(self.database.con.cursor()) as cursor:
-            cursor.execute('SELECT * FROM collections WHERE reference = ? AND project = ?',
-                           (reference, self.identifier))
-            row = cursor.fetchone()
 
-            if row is not None:
-                return Collection(self.database, row)
+        label = Label.query.get(project=self, reference=reference)
+        is_new = False
 
-            return None
+        if label is None:
+            label = Label.new(project=self, reference=reference)
+            is_new = True
+
+        label.set_name(name)
+        label.set_parent(parent_id)
+
+        self.commit()
+
+        return label, is_new
 
     def create_collection(self,
                           reference: str,
@@ -166,74 +110,77 @@ class Project:
                           description: str,
                           position: int,
                           autoselect: bool):
-        autoselect = 1 if autoselect else 0
 
-        with closing(self.database.con.cursor()) as cursor:
-            cursor.execute('''
-                INSERT INTO collections
-                    (project, reference, name, description, position, autoselect)
-                VALUES (?, ?, ?, ?, ?, ?)
-                ON CONFLICT (project, reference) DO
-                UPDATE SET name = ?, description = ?, position = ?, autoselect = ?
-            ''', (self.identifier, reference, name, description, position, autoselect,
-                  name, description, position, autoselect))
-
-            # lastrowid is 0 if on conflict clause applies.
-            # If this is the case we do an extra query to receive the row id.
-            if cursor.lastrowid > 0:
-                row_id = cursor.lastrowid
-                insert = True
-            else:
-                cursor.execute('SELECT id FROM collections WHERE project = ? AND reference = ?',
-                               (self.identifier, reference))
-                row_id = cursor.fetchone()[0]
-                insert = False
-
-        return self.collection(row_id), insert
-
-    def remove(self) -> None:
-        """
-        remove this project from the database
 
-        :return:
-        """
-        with closing(self.database.con.cursor()) as cursor:
-            cursor.execute('DELETE FROM projects WHERE id = ?', [self.identifier])
+        collection = Collection.query.get(project=self, reference=reference)
+        is_new = False
+
+        if collection is None:
+            collection = Collection.new(project=self,
+                                        reference=reference)
+            is_new = True
+
+        collection.name = name
+        collection.description = description
+        collection.position = position
+        collection.autoselect = autoselect
+        self.commit()
+
+        return collection, is_new
 
-    def set_name(self, name: str) -> None:
+    def add_file(self, uuid: str, file_type: str, name: str, extension: str, size: int,
+                 filename: str, frames: int = None, fps: float = None) -> T.Tuple[File, bool]:
         """
-        set this projects name
+        add a file to this project
 
-        :param name: new name
-        :return:
+        :param uuid: unique identifier which is used for temporary files
+        :param file_type: file type (either image or video)
+        :param name: file name
+        :param extension: file extension
+        :param size: file size
+        :param filename: actual name in filesystem
+        :param frames: frame count
+        :param fps: frames per second
+        :return: file
         """
-        with closing(self.database.con.cursor()) as cursor:
-            cursor.execute('UPDATE projects SET name = ? WHERE id = ?', (name, self.identifier))
-            self.name = name
+        path = join(self.data_folder, filename + extension)
+
+        file = File.objects.get(project=self, path=path)
+        is_new = False
 
-    def set_description(self, description: str) -> None:
+        if file is None:
+            file = File.new(uuid=uuid, project=self, path=path)
+            is_new = True
+
+        file.type = file_type
+        file.name = name
+        file.extension = extension
+        file.size = size
+        file.frames = frames
+        file.fps = fps
+
+        self.commit()
+        return file, is_new
+
+
+    def set_description(self, description: str):
         """
         set this projects description
 
         :param description: new description
         :return:
         """
-        with closing(self.database.con.cursor()) as cursor:
-            cursor.execute('UPDATE projects SET description = ? WHERE id = ?',
-                           (description, self.identifier))
-            self.description = description
-
+        self.description = description
+        self
     def count_files(self) -> int:
         """
         count files associated with this project
 
         :return: count
         """
-        with closing(self.database.con.cursor()) as cursor:
-            cursor.execute('SELECT COUNT(*) FROM files WHERE project = ?', [self.identifier])
-            return cursor.fetchone()[0]
+        return self.files.count()
 
-    def files(self, offset: int = 0, limit: int = -1) -> Iterator[File]:
+    def get_files(self, offset: int = 0, limit: int = -1) -> T.Iterator[File]:
         """
         get an iterator of files associated with this project
 
@@ -241,14 +188,7 @@ class Project:
         :param limit: file limit
         :return: iterator of files
         """
-        with closing(self.database.con.cursor()) as cursor:
-            cursor.execute('SELECT * FROM files WHERE project = ? ORDER BY id ASC LIMIT ? OFFSET ?',
-                           (self.identifier, limit, offset))
-
-            return map(
-                lambda row: File(self.database, row),
-                cursor.fetchall()
-            )
+        return self.files.order_by(File.id.acs()).offset(offset).limit(limit)
 
     def count_files_without_results(self) -> int:
         """
@@ -256,6 +196,8 @@ class Project:
 
         :return: count
         """
+        raise NotImplementedError
+
         with closing(self.database.con.cursor()) as cursor:
             cursor.execute('''
                 SELECT COUNT(*)
@@ -271,6 +213,8 @@ class Project:
 
         :return: list of files
         """
+        raise NotImplementedError
+
         with closing(self.database.con.cursor()) as cursor:
             cursor.execute('''
                 SELECT files.*
@@ -283,91 +227,19 @@ class Project:
             for row in cursor:
                 yield File(self.database, row)
 
-    def count_files_without_collection(self) -> int:
-        """
-        count files associated with this project but with no collection
-
-        :return: count
-        """
-        with closing(self.database.con.cursor()) as cursor:
-            cursor.execute('SELECT COUNT(*) FROM files WHERE project = ? AND collection IS NULL',
-                           [self.identifier])
-            return cursor.fetchone()[0]
-
-    def files_without_collection(self, offset=0, limit=-1) -> Iterator[File]:
+    def files_without_collection(self, offset: int = 0, limit: int = -1) -> Iterator[File]:
         """
         get an iterator of files without not associated with any collection
 
         :return: list of files
         """
-        with closing(self.database.con.cursor()) as cursor:
-            cursor.execute('''
-                SELECT * FROM files
-                WHERE files.project = ? AND files.collection IS NULL
-                ORDER BY id ASC
-                LIMIT ? OFFSET ?
-            ''', (self.identifier, limit, offset))
-
-            for row in cursor:
-                yield File(self.database, row)
-
-    def file(self, identifier) -> Optional[File]:
-        """
-        get a file using its unique identifier
-
-        :param identifier: unique identifier
-        :return: file
-        """
-        with closing(self.database.con.cursor()) as cursor:
-            cursor.execute('SELECT * FROM files WHERE id = ? AND project = ?',
-                           (identifier, self.identifier))
-            row = cursor.fetchone()
+        return self.get_files(offset, limit).filter(File.collection_id == None)
 
-            if row is not None:
-                return File(self.database, row)
 
-            return None
-
-    def add_file(self, uuid: str, file_type: str, name: str, extension: str, size: int,
-                 filename: str, frames: int = None, fps: float = None) -> Tuple[File, bool]:
+    def count_files_without_collection(self) -> int:
         """
-        add a file to this project
+        count files associated with this project but with no collection
 
-        :param uuid: unique identifier which is used for temporary files
-        :param file_type: file type (either image or video)
-        :param name: file name
-        :param extension: file extension
-        :param size: file size
-        :param filename: actual name in filesystem
-        :param frames: frame count
-        :param fps: frames per second
-        :return: file
+        :return: count
         """
-        created = int(time())
-        path = join(self.data_folder, filename + extension)
-
-        with closing(self.database.con.cursor()) as cursor:
-            cursor.execute('''
-                INSERT INTO files (
-                    uuid, project, type, name, extension, size, created, path, frames, fps
-                )
-                VALUES (  
-                       ?,       ?,    ?,    ?,         ?,    ?,       ?,    ?,      ?,   ?
-                )
-                ON CONFLICT (project, path) DO
-                UPDATE SET type = ?, name = ?, extension = ?, size = ?, frames = ?, fps = ?
-            ''', (uuid, self.identifier, file_type, name, extension, size, created, path, frames,
-                  fps, file_type, name, extension, size, frames, fps))
-
-            # lastrowid is 0 if on conflict clause applies.
-            # If this is the case we do an extra query to receive the row id.
-            if cursor.lastrowid > 0:
-                row_id = cursor.lastrowid
-                insert = True
-            else:
-                cursor.execute('SELECT id FROM files WHERE project = ? AND path = ?',
-                               (self.identifier, path))
-                row_id = cursor.fetchone()[0]
-                insert = False
-
-        return self.file(row_id), insert
+        return self.files_without_collection().count()

+ 29 - 37
pycs/database/Result.py

@@ -1,31 +1,28 @@
-from contextlib import closing
+import typing as T
 
+from contextlib import closing
 from json import dumps, loads
 
+from pycs import db
+from pycs.database.base import BaseModel
 
-class Result:
-    """
-    database class for results
-    """
+class Result(BaseModel):
 
-    def __init__(self, database, row):
-        self.database = database
+    file_id = db.Column(
+        db.Integer,
+        db.ForeignKey("file.id", ondelete="CASCADE"),
+        nullable=False)
 
-        self.identifier = row[0]
-        self.file_id = row[1]
-        self.origin = row[2]
-        self.type = row[3]
-        self.label = row[4]
-        self.data = loads(row[5]) if row[5] is not None else None
+    origin = db.Column(db.String, nullable=False)
+    type = db.Column(db.String, nullable=False)
 
-    def remove(self):
-        """
-        remove this result from the database
+    label_id = db.Column(
+        db.Integer,
+        db.ForeignKey("label.id", ondelete="SET NULL"),
+        nullable=True)
+
+    data = db.Column(db.String)
 
-        :return:
-        """
-        with closing(self.database.con.cursor()) as cursor:
-            cursor.execute('DELETE FROM results WHERE id = ?', [self.identifier])
 
     def set_origin(self, origin: str):
         """
@@ -34,33 +31,28 @@ class Result:
         :param origin: either 'user' or 'pipeline'
         :return:
         """
-        with closing(self.database.con.cursor()) as cursor:
-            cursor.execute('UPDATE results SET origin = ? WHERE id = ?', (origin, self.identifier))
-            self.origin = origin
+        self.origin = origin
+        self.commit()
+
 
     def set_label(self, label: int):
         """
-        set this results label
+        set this results origin
 
-        :param label: label id
+        :param label: label ID
         :return:
         """
-        with closing(self.database.con.cursor()) as cursor:
-            cursor.execute('UPDATE results SET label = ? WHERE id = ?', (label, self.identifier))
-            self.label = label
+        self.label_id = label
+        self.commit()
 
-    def set_data(self, data: dict):
+    def set_data(self, data: T.Optional[dict]):
         """
         set this results data object
 
         :param data: data object
         :return:
         """
-        if data is None:
-            data_txt = None
-        else:
-            data_txt = dumps(data)
-
-        with closing(self.database.con.cursor()) as cursor:
-            cursor.execute('UPDATE results SET data = ? WHERE id = ?', (data_txt, self.identifier))
-            self.data = data
+        data = data if data is None else json.dumps(data)
+
+        self.data = data
+        self.commit()

+ 37 - 0
pycs/database/base.py

@@ -0,0 +1,37 @@
+
+from pycs import db
+
+class BaseModel(db.Model):
+    __abstract__ = True
+
+    id = db.Column(db.Integer, primary_key=True)
+
+
+    def remove(self) -> None:
+        """
+        remove this instance from the database
+
+        :return:
+        """
+        db.session.delete(self)
+        self.commit()
+
+    @classmethod
+    def new(cls, commit=False, **kwargs):
+        obj = cls(**kwargs)
+        db.session.add(obj)
+
+        if commit:
+            self.commit()
+
+    def commit(self):
+        db.session.commit()
+
+class NamedBaseModel(BaseModel):
+    __abstract__ = True
+
+    name = db.Column(db.String, nullable=False)
+
+    def set_name(self, name: str):
+        self.name = name
+        self.commit()

+ 11 - 16
pycs/frontend/WebServer.py

@@ -1,16 +1,10 @@
-# import eventlet
-# import socketio
-from flask_socketio import SocketIO
+import os
+import glob
 
-from glob import glob
-from os import getcwd
-from os import path
-from os.path import exists
 from logging.config import dictConfig
 
-from flask import Flask
-from flask import send_from_directory
 from flask import send_from_directory
+from flask_socketio import SocketIO
 
 from pycs.database.Database import Database
 from pycs.frontend.endpoints.ListJobs import ListJobs
@@ -61,12 +55,11 @@ class WebServer:
     """
 
     # pylint: disable=line-too-long
-    def __init__(self, settings: dict):
+    def __init__(self, app, settings: dict):
 
         dictConfig(settings["logging"])
         # initialize flask app instance
-        self.app = Flask(__name__)
-
+        self.app = app
 
         # initialize database
         db_file = settings["database"]
@@ -81,7 +74,9 @@ class WebServer:
         self.logger.info('Creating pipeline cache')
         self.pipelines = PipelineCache(self.jobs)
 
-        init_func = self.production_init if exists('webui/index.html') else self.development_init
+        PRODUCTION = os.path.exists('webui/index.html')
+
+        init_func = self.production_init if PRODUCTION else self.development_init
         kwargs, static_files = init_func()
 
         self.sio = SocketIO(self.app, **kwargs)#socketio.Server(**kwargs)
@@ -135,7 +130,7 @@ class WebServer:
         @self.app.route('/', methods=['GET'])
         def index():
             # pylint: disable=unused-variable
-            return send_from_directory(path.join(getcwd(), 'webui'), 'index.html')
+            return send_from_directory(os.path.join(os.getcwd(), 'webui'), 'index.html')
 
         return kwargs, self.static_files
 
@@ -144,12 +139,12 @@ class WebServer:
         # find static files and folders
         static_files = {}
 
-        for file_path in glob('webui/*'):
+        for file_path in glob.glob('webui/*'):
             file_path = file_path.replace('\\', '/')
             static_files[file_path[5:]] = file_path
 
         # separately add svg files and set their correct mime type
-        for svg_path in glob('webui/img/*.svg'):
+        for svg_path in glob.glob('webui/img/*.svg'):
             svg_path = svg_path.replace('\\', '/')
             static_files[svg_path[5:]] = {'content_type': 'image/svg+xml', 'filename': svg_path}
 

+ 2 - 0
requirements.txt

@@ -5,6 +5,8 @@ scipy
 eventlet
 flask
 flask-socketio
+flask-sqlalchemy
+flask-migrate
 # python-socketio
 munch
 scikit-image