Przeglądaj źródła

reworked the models to the (hopefully) final state

Dimitri Korsch 3 lat temu
rodzic
commit
714f6d90cf

+ 41 - 2
pycs/database/Collection.py

@@ -8,7 +8,9 @@ class Collection(NamedBaseModel):
 
     # table columns
     project_id = db.Column(
-        db.Integer, db.ForeignKey("project.id", ondelete="CASCADE"), nullable=False)
+        db.Integer,
+        db.ForeignKey("project.id", ondelete="CASCADE"),
+        nullable=False)
 
     reference = db.Column(
         db.String, nullable=False)
@@ -29,8 +31,45 @@ class Collection(NamedBaseModel):
     )
 
     # relationships to other models
-    files = db.relationship("File", backref="collection", lazy=True)
+    files = db.relationship("File", backref="collection", lazy="dynamic")
 
+    serialize_only = NamedBaseModel.serialize_only + (
+        "project_id",
+        "reference",
+        "description",
+        "position",
+        "autoselect",
+    )
 
     def count_files(self) -> int:
         return self.files.count()
+
+    def get_files(self, offset: int = 0, limit: int = -1):
+        """
+        get an iterator of files associated with this project
+
+        :param offset: file offset
+        :param limit: file limit
+        :return: iterator of files
+        """
+        from pycs.database.File import File
+        return self.files.order_by(File.id).offset(offset).limit(limit)
+
+    @staticmethod
+    def update_autoselect(collections: List[Collection]) -> List[Collection]:
+        """ disable autoselect if there are no elements in the collection """
+
+        found = False
+
+        for collection in collections:
+            if not collection.autoselect:
+                continue
+
+            if found:
+                collection.autoselect = False
+
+            elif collection.count_files() == 0:
+                collection.autoselect = False
+                found = True
+
+        return collections

+ 71 - 35
pycs/database/File.py

@@ -1,10 +1,9 @@
 from __future__ import annotations
 
-import json
-import os
 import typing as T
 
 from datetime import datetime
+from pathlib import Path
 
 from pycs import db
 from pycs.database.Collection import Collection
@@ -51,27 +50,46 @@ class File(NamedBaseModel):
     )
 
 
-    # relationships to other models
-    results = db.relationship("Result", backref="file", lazy=True)
+    results = db.relationship("Result", backref="file",
+        lazy="dynamic", passive_deletes=True)
 
 
+    serialize_only = NamedBaseModel.serialize_only + (
+        "uuid",
+        "extension",
+        "type",
+        "size",
+        "created",
+        "path",
+        "frames",
+        "fps",
+        "project_id",
+        "collection_id",
+    )
+
+    @property
+    def filename(self):
+        return f"{self.name}{self.extension}"
+
     @property
-    def absolute_path(self):
-        if os.path.isabs(self.path):
-            return self.path
+    def absolute_path(self) -> str:
+        path = Path(self.path)
 
-        return os.path.join(os.getcwd(), self.path)
+        if path.is_absolute():
+            return str(path)
+
+        return str(Path.cwd() / path)
 
     @commit_on_return
-    def set_collection(self, id: T.Optional[int]):
+    def set_collection(self, collection_id: T.Optional[int]):
         """
         set this file's collection
 
-        :param id: new collection id
+        :param collection_id: new collection id
         :return:
         """
 
-        self.collection_id = id
+        self.collection_id = collection_id
 
     @commit_on_return
     def set_collection_by_reference(self, collection_reference: T.Optional[str]):
@@ -85,7 +103,7 @@ class File(NamedBaseModel):
             self.set_collection(None)
 
         collection = Collection.query.filter_by(reference=collection_reference).one()
-        self.collection = collection
+        self.collection_id = collection.id
 
     def _get_another_file(self, *query) -> T.Optional[File]:
         """
@@ -93,9 +111,7 @@ class File(NamedBaseModel):
 
         :return: another file or None
         """
-        return File.query.filter(File.project_id == self.project_id, *query)\
-            .order_by(File.id.desc())\
-            .first()
+        return File.query.filter(File.project_id == self.project_id, *query)
 
     def next(self) -> T.Optional[File]:
         """
@@ -103,8 +119,9 @@ class File(NamedBaseModel):
 
         :return: another file or None
         """
-        query = File.id > self.id,
-        return self._get_another_file(*query)
+
+        return self._get_another_file(File.id > self.id)\
+            .order_by(File.id).first()
 
 
     def previous(self) -> T.Optional[File]:
@@ -113,8 +130,9 @@ class File(NamedBaseModel):
 
         :return: another file or None
         """
-        query = File.id < self.id,
-        return self._get_another_file(*query)
+
+        return self._get_another_file(File.id < self.id)\
+            .order_by(File.id.desc()).first()
 
 
     def next_in_collection(self) -> T.Optional[File]:
@@ -123,8 +141,9 @@ class File(NamedBaseModel):
 
         :return: another file or None
         """
-        query = File.id > self.id, File.collection_id == self.collection_id
-        return self._get_another_file(*query)
+        return self._get_another_file(
+            File.id > self.id, File.collection_id == self.collection_id)\
+            .order_by(File.id).first()
 
 
     def previous_in_collection(self) -> T.Optional[File]:
@@ -133,30 +152,47 @@ class File(NamedBaseModel):
 
         :return: another file or None
         """
-        query = File.id < self.id, File.collection_id == self.collection_id
-        return self._get_another_file(*query)
+        return self._get_another_file(
+            File.id < self.id, File.collection_id == self.collection_id)\
+            .order_by(File.id.desc()).first()
 
 
     def result(self, id: int) -> T.Optional[Result]:
         return self.results.get(id)
 
 
-    def create_result(self, origin, result_type, label, data: T.Optional[dict] = None):
-        data = data if data is None else json.dumps(data)
-
-        result = Result.new(commit=True,
-                            file=self,
+    @commit_on_return
+    def create_result(self,
+                      origin: str,
+                      result_type: str,
+                      label: T.Optional[T.Union[Label, int]] = None,
+                      data: T.Optional[dict] = None) -> Result:
+
+        result = Result.new(commit=False,
+                            file_id=self.id,
                             origin=origin,
-                            type=result_type,
-                            label=label,
-                            data=data)
+                            type=result_type)
+
+        result.data = data
+
+        if label is not None:
+            assert isinstance(label, (int, Label)), f"Wrong label type: {type(label)}"
+
+            if isinstance(label, Label):
+                label = label.id
+
+            result.label_id = label
+
         return result
 
 
-    def remove_results(self, origin='pipeline'):
+    def remove_results(self, origin='pipeline') -> T.List[Result]:
 
-        results = Result.query.filter(Result.file == self, Result.origin == origin)
+        results = Result.query.filter(
+            Result.file_id == self.id,
+            Result.origin == origin)
 
-        results.remove()
+        _results = results.all()
+        results.delete()
 
-        return results
+        return _results

+ 21 - 5
pycs/database/Label.py

@@ -1,12 +1,10 @@
-from __future__ import annotations
-from contextlib import closing
 from datetime import datetime
 
 from pycs import db
 from pycs.database.base import NamedBaseModel
 from pycs.database.util import commit_on_return
 
-def compare_children(start_label: Label, id: int):
+def compare_children(start_label: Label, id: int) -> bool:
     """ check for cyclic relationships """
 
     labels_to_check = [start_label]
@@ -21,6 +19,9 @@ def compare_children(start_label: Label, id: int):
 
     return True
 
+def _Label_id():
+    return Label.id
+
 class Label(NamedBaseModel):
 
     id = db.Column(db.Integer, primary_key=True)
@@ -44,10 +45,25 @@ class Label(NamedBaseModel):
     )
 
     # relationships to other models
-    parent = db.relationship("Label", backref="children", remote_side=[id])
+    parent = db.relationship("Label",
+        backref="children",
+        remote_side=_Label_id)
+
+    results = db.relationship("Result",
+        backref="label",
+        passive_deletes=True,
+        lazy="dynamic",
+    )
+
+    serialize_only = NamedBaseModel.serialize_only + (
+        "project_id",
+        "parent_id",
+        "reference",
+        "children",
+    )
 
     @commit_on_return
-    def set_parent(self, parent_id: int, commit: bool = True):
+    def set_parent(self, parent_id: int) -> None:
         """
         set this labels parent
 

+ 68 - 5
pycs/database/LabelProvider.py

@@ -1,5 +1,12 @@
+import json
+import re
+
+from pathlib import Path
+
 from pycs import db
 from pycs.database.base import NamedBaseModel
+from pycs.interfaces.LabelProvider import LabelProvider as LabelProviderInterface
+
 
 class LabelProvider(NamedBaseModel):
     """
@@ -8,13 +15,51 @@ class LabelProvider(NamedBaseModel):
 
     description = db.Column(db.String)
     root_folder = db.Column(db.String, nullable=False, unique=True)
+    configuration_file = db.Column(db.String, nullable=False)
 
     # relationships to other models
-    projects = db.relationship("Project", backref="label_provider", lazy=True)
+    projects = db.relationship("Project", backref="label_provider", lazy="dynamic")
+
+    # contraints
+    __table_args__ = (
+        db.UniqueConstraint('root_folder', 'configuration_file'),
+    )
+
+    serialize_only = NamedBaseModel.serialize_only + (
+        "description",
+        "root_folder",
+        "configuration_file",
+    )
+
+    @classmethod
+    def discover(cls, root: Path):
+
+        for folder, conf_path in _find_files(root):
+            with open(conf_path) as f:
+                config = json.load(f)
+
+            provider, _ = cls.get_or_create(
+                root_folder=str(folder),
+                configuration_file=conf_path.name
+            )
+
+            provider.name = config['name']
+
+            # returns None if not present
+            provider.description = config.get('description')
+
+            db.session.flush()
+        db.session.commit()
 
     @property
-    def configuration_path(self):
-        return path.join(self.root_folder, self.configuration_file)
+    def root(self) -> Path:
+        return Path(self.root_folder)
+
+
+    @property
+    def configuration_file_path(self) -> str:
+        return str(self.root / self.configuration_file)
+
 
     def load(self) -> LabelProviderInterface:
         """
@@ -23,11 +68,11 @@ class LabelProvider(NamedBaseModel):
         :return: LabelProvider instance
         """
         # load configuration.json
-        with open(self.configuration_path, 'r') as configuration_file:
+        with open(self.configuration_file_path) as configuration_file:
             configuration = json.load(configuration_file)
 
         # load code
-        code_path = path.join(self.root_folder, configuration['code']['module'])
+        code_path = str(self.root / configuration['code']['module'])
         module_name = code_path.replace('/', '.').replace('\\', '.')
         class_name = configuration['code']['class']
 
@@ -36,3 +81,21 @@ class LabelProvider(NamedBaseModel):
 
         # return instance
         return class_attr(self.root_folder, configuration)
+
+
+def _find_files(root: str, config_regex=re.compile(r'^configuration(\d+)?\.json$')):
+    # list folders in labels/
+    for folder in Path(root).glob('*'):
+        # list files
+        for file_path in folder.iterdir():
+
+            # filter configuration files
+            if not file_path.is_file():
+                continue
+
+            if config_regex.match(file_path.name) is None:
+                continue
+
+            # yield element
+            yield folder, file_path
+

+ 45 - 7
pycs/database/Model.py

@@ -1,5 +1,7 @@
 import json
 
+from pathlib import Path
+
 from pycs import db
 from pycs.database.base import NamedBaseModel
 from pycs.database.util import commit_on_return
@@ -14,22 +16,58 @@ class Model(NamedBaseModel):
     supports_encoded = db.Column(db.String, nullable=False)
 
     # relationships to other models
-    projects = db.relationship("Project", backref="model", lazy=True)
+    projects = db.relationship("Project", backref="model", lazy="dynamic")
+
+    serialize_only = NamedBaseModel.serialize_only + (
+        "description",
+        "root_folder",
+    )
+
+
+    def serialize(self):
+        result = super().serialize()
+        result["supports"] = self.supports
+        return result
+
+    @classmethod
+    def discover(cls, root: Path, config_name: str = "configuration.json"):
+        for folder in Path(root).glob("*"):
+            with open(folder / config_name) as f:
+                config = json.load(f)
+
+            # extract data
+            name = config['name']
+            description = config.get('description', None)
+            supports = config['supports']
+
+            model, _ = cls.get_or_create(root_folder=str(folder))
+
+            model.name = name
+            model.description = description
+            model.supports = supports
+
+            db.session.flush()
+        db.session.commit()
 
     @property
     def supports(self):
         return json.loads(self.supports_encoded)
 
+    @supports.setter
+    def supports(self, value):
+        if isinstance(value, str):
+            self.supports_encoded = value
+
+        elif isinstance(value, (dict, list)):
+            self.supports_encoded = json.dumps(value)
+
+        else:
+            raise ValueError(f"Not supported type: {type(value)}")
 
     @commit_on_return
     def copy_to(self, new_name: str, new_root_folder: str):
 
-        model = Model.query.get(root_folder=new_root_folder)
-        is_new = False
-
-        if model is None:
-            model = Model.new(root_folder=new_root_folder)
-            is_new = True
+        model, is_new = Model.get_or_create(root_folder=new_root_folder)
 
         model.name = name
         model.description = self.description

+ 145 - 108
pycs/database/Project.py

@@ -1,12 +1,8 @@
+import os
 import typing as T
+import warnings
 
-from contextlib import closing
 from datetime import datetime
-from os.path import join
-from typing import Iterator
-from typing import List
-from typing import Optional
-from typing import Tuple
 
 from pycs import db
 from pycs.database.base import NamedBaseModel
@@ -15,7 +11,7 @@ from pycs.database.Collection import Collection
 from pycs.database.File import File
 from pycs.database.Label import Label
 from pycs.database.util import commit_on_return
-from pycs.database.util.TreeNodeLabel import TreeNodeLabel
+
 
 class Project(NamedBaseModel):
     description = db.Column(db.String)
@@ -41,76 +37,74 @@ class Project(NamedBaseModel):
     __table_args__ = ()
 
     # relationships to other models
-    files = db.relationship("File", backref="project", lazy=True)
-    labels = db.relationship("Label", backref="project", lazy=True)
-    collections = db.relationship("Collection", backref="project", lazy=True)
-
-
-    def label(self, id: int) -> T.Optional[Label]:
+    files = db.relationship(
+        "File",
+        backref="project",
+        lazy="dynamic")
+
+    labels = db.relationship(
+        "Label",
+        backref="project",
+        lazy="dynamic")
+
+    collections = db.relationship(
+        "Collection",
+        backref="project",
+        lazy="dynamic")
+
+
+    serialize_only = NamedBaseModel.serialize_only + (
+        "created",
+        "description",
+        "model_id",
+        "label_provider_id",
+        "root_folder",
+        "external_data",
+        "data_folder",
+    )
+
+
+    def label(self, identifier: int) -> T.Optional[Label]:
         """
         get a label using its unique identifier
 
         :param identifier: unique identifier
         :return: label
         """
-        return self.labels.get(id)
+        return self.labels.filter(Label.id == identifier).one_or_none()
 
-    def file(self, id: int) -> T.Optional[Label]:
-        """
-        get a file using its unique identifier
 
-        :param identifier: unique identifier
-        :return: file
+    def label_by_reference(self, reference: str) -> T.Optional[Label]:
         """
-        return self.files.get(id)
+        get a label using its reference string
 
-    def collection(self, id: int) -> T.Optional[Collection]:
+        :param reference: reference string
+        :return: label
         """
-        get a collection using its unique identifier
+        return self.labels.filter(Label.reference == reference).one_or_none()
 
-        :param identifier: unique identifier
-        :return: collection
-        """
-        return self.collections.get(id)
 
-    def collection_by_reference(self, reference: str) -> T.Optional[Collection]:
+    def file(self, identifier: int) -> T.Optional[Label]:
         """
-        get a collection using its unique identifier
+        get a file using its unique identifier
 
         :param identifier: unique identifier
-        :return: collection
+        :return: file
         """
-        return self.collections.filter_by(reference=reference).one()
+        return self.files.filter(File.id == identifier).one_or_none()
 
-    @commit_on_return
-    def create_label(self, name: str, reference: str = None,
-                     parent_id: int = None,
-                     hierarchy_level: str = None) -> Tuple[Optional[Label], bool]:
-        """
-        create a label for this project. If there is already a label with the same reference
-        in the database its name is updated.
 
-        :param name: label name
-        :param reference: label reference
-        :param parent_id: parent's identifier
-        :param hierarchy_level: hierarchy level name
-        :return: created or edited label, insert
+    def label_tree(self) -> T.List[Label]:
         """
+        get a list of root labels associated with this project
 
-        label = Label.query.get(project=self, reference=reference)
-        is_new = False
-
-        if label is None:
-            label = Label.new(project=self, reference=reference)
-            is_new = True
-
-        label.set_name(name, commit=False)
-        label.set_parent(parent_id, commit=False)
-        label.hierarchy_level = hierarchy_level
+        :return: list of labels
+        """
+        warnings.warn("Check performance of this method!")
+        return self.labels.filter(Label.parent_id == None).all()
 
-        return label, is_new
 
-    def label_tree(self) -> List[TreeNodeLabel]:
+    def label_tree_original(self):
         """
         get a list of root labels associated with this project
 
@@ -144,13 +138,64 @@ class Project(NamedBaseModel):
 
             return result
 
+
+    def collection(self, identifier: int) -> T.Optional[Collection]:
+        """
+        get a collection using its unique identifier
+
+        :param identifier: unique identifier
+        :return: collection
+        """
+        return self.collections.filter(Collection.id == identifier).one_or_none()
+
+
+    def collection_by_reference(self, reference: str) -> T.Optional[Collection]:
+        """
+        get a collection using its unique identifier
+
+        :param identifier: unique identifier
+        :return: collection
+        """
+        return self.collections.filter(Collection.reference == reference).one_or_none()
+
+
+    @commit_on_return
+    def create_label(self, name: str,
+                     reference: str = None,
+                     parent_id: int = None,
+                     hierarchy_level: str = None) -> T.Tuple[T.Optional[Label], bool]:
+        """
+        create a label for this project. If there is already a label with the same reference
+        in the database its name is updated.
+
+        :param name: label name
+        :param reference: label reference
+        :param parent_id: parent's identifier
+        :param hierarchy_level: hierarchy level name
+        :return: created or edited label, insert
+        """
+
+        label = Label.query.get(project=self, reference=reference)
+        is_new = False
+
+        if label is None:
+            label = Label.new(project=self, reference=reference)
+            is_new = True
+
+        label.set_name(name, commit=False)
+        label.set_parent(parent_id, commit=False)
+        label.hierarchy_level = hierarchy_level
+
+        return label, is_new
+
+
     @commit_on_return
     def create_collection(self,
                           reference: str,
                           name: str,
                           description: str,
                           position: int,
-                          autoselect: bool) -> Tuple[Collection, bool]:
+                          autoselect: bool) -> T.Tuple[Collection, bool]:
         """
         create a new collection associated with this project
 
@@ -163,13 +208,9 @@ class Project(NamedBaseModel):
         :return: collection object, insert
         """
 
-        collection = Collection.query.get(project=self, reference=reference)
-        is_new = False
 
-        if collection is None:
-            collection = Collection.new(project=self,
-                                        reference=reference)
-            is_new = True
+        collection, is_new = Collection.get_or_create(
+            project_id=self.id, reference=reference)
 
         collection.name = name
         collection.description = description
@@ -178,9 +219,17 @@ class Project(NamedBaseModel):
 
         return collection, is_new
 
+
     @commit_on_return
-    def add_file(self, uuid: str, file_type: str, name: str, extension: str, size: int,
-                 filename: str, frames: int = None, fps: float = None) -> T.Tuple[File, bool]:
+    def add_file(self,
+                 uuid: str,
+                 file_type: str,
+                 name: str,
+                 extension: str,
+                 size: int,
+                 filename: str,
+                 frames: int = None,
+                 fps: float = None) -> T.Tuple[File, bool]:
         """
         add a file to this project
 
@@ -194,14 +243,10 @@ class Project(NamedBaseModel):
         :param fps: frames per second
         :return: file
         """
-        path = join(self.data_folder, filename + extension)
-
-        file = File.objects.get(project=self, path=path)
-        is_new = False
+        path = os.path.join(self.data_folder, f"{filename}{extension}")
 
-        if file is None:
-            file = File.new(uuid=uuid, project=self, path=path)
-            is_new = True
+        file, is_new = File.get_or_create(
+            project_id=self.id, path=path)
 
         file.type = file_type
         file.name = name
@@ -213,15 +258,6 @@ class Project(NamedBaseModel):
         return file, is_new
 
 
-    def set_description(self, description: str):
-        """
-        set this projects description
-
-        :param description: new description
-        :return:
-        """
-        self.description = description
-        self
     def count_files(self) -> int:
         """
         count files associated with this project
@@ -230,7 +266,8 @@ class Project(NamedBaseModel):
         """
         return self.files.count()
 
-    def get_files(self, offset: int = 0, limit: int = -1) -> T.Iterator[File]:
+
+    def get_files(self, offset: int = 0, limit: int = -1) -> T.List[File]:
         """
         get an iterator of files associated with this project
 
@@ -238,7 +275,17 @@ class Project(NamedBaseModel):
         :param limit: file limit
         :return: iterator of files
         """
-        return self.files.order_by(File.id.acs()).offset(offset).limit(limit)
+        return self.files.order_by(File.id).offset(offset).limit(limit).all()
+
+
+    def _files_without_results(self):
+        """
+        get files without any results
+
+        :return: a query object
+        """
+        return self.files.filter(~File.results.any())
+
 
     def count_files_without_results(self) -> int:
         """
@@ -246,50 +293,40 @@ class Project(NamedBaseModel):
 
         :return: count
         """
-        raise NotImplementedError
 
-        with closing(self.database.con.cursor()) as cursor:
-            cursor.execute('''
-                SELECT COUNT(*)
-                FROM files
-                LEFT JOIN results ON files.id = results.file
-                WHERE files.project = ? AND results.id IS NULL
-            ''', [self.identifier])
-            return cursor.fetchone()[0]
+        return self._files_without_results().count()
 
-    def files_without_results(self) -> Iterator[File]:
+
+    def files_without_results(self) -> T.List[File]:
         """
-        get an iterator of files without associated results
+        get a list of files without associated results
 
         :return: list of files
         """
-        raise NotImplementedError
+        return self._files_without_results().all()
 
-        with closing(self.database.con.cursor()) as cursor:
-            cursor.execute('''
-                SELECT files.*
-                FROM files
-                LEFT JOIN results ON files.id = results.file
-                WHERE files.project = ? AND results.id IS NULL
-                ORDER BY id ASC
-            ''', [self.identifier])
 
-            for row in cursor:
-                yield File(self.database, row)
+    def _files_without_collection(self, offset: int = 0, limit: int = -1):
+        """
+        get files without a collection
 
-    def files_without_collection(self, offset: int = 0, limit: int = -1) -> Iterator[File]:
+        :return: a query object
         """
-        get an iterator of files without not associated with any collection
+        return self.get_files(offset, limit).filter(File.collection_id == None)
+
+    def files_without_collection(self, offset: int = 0, limit: int = -1) -> T.List[File]:
+        """
+        get a list of files without a collection
 
         :return: list of files
         """
-        return self.get_files(offset, limit).filter(File.collection_id == None)
+        return self._files_without_collection(offset=offset, limit=limit).all()
 
 
     def count_files_without_collection(self) -> int:
         """
-        count files associated with this project but with no collection
+        count files associated with this project but without a collection
 
         :return: count
         """
-        return self.files_without_collection().count()
+        return self._files_without_collection().count()

+ 28 - 12
pycs/database/Result.py

@@ -22,8 +22,35 @@ class Result(BaseModel):
         db.ForeignKey("label.id", ondelete="SET NULL"),
         nullable=True)
 
-    data = db.Column(db.String)
+    data_encoded = db.Column(db.String)
 
+    serialize_only = BaseModel.serialize_only + (
+        "file_id",
+        "origin",
+        "type",
+        "label_id",
+        "data",
+    )
+
+    def serialize(self):
+        result = super().serialize()
+        result["data"] = self.data
+        return result
+
+    @property
+    def data(self):
+        return None if self.data_encoded is None else json.loads(self.data_encoded)
+
+    @data.setter
+    def data(self, value):
+        if isinstance(value, str) or value is None:
+            self.data_encoded = value
+
+        elif isinstance(value, (dict, list)):
+            self.data_encoded = json.dumps(value)
+
+        else:
+            raise ValueError(f"Not supported type: {type(value)}")
 
     @commit_on_return
     def set_origin(self, origin: str):
@@ -46,14 +73,3 @@ class Result(BaseModel):
         """
         self.label_id = label
 
-    @commit_on_return
-    def set_data(self, data: T.Optional[dict]):
-        """
-        set this results data object
-
-        :param data: data object
-        :return:
-        """
-        data = data if data is None else json.dumps(data)
-
-        self.data = data

+ 68 - 7
pycs/database/base.py

@@ -1,40 +1,101 @@
+from __future__ import annotations
 
+import datetime
+import typing as T
+
+from flask import abort
+from sqlalchemy_serializer import SerializerMixin
+
+from pycs import app
 from pycs import db
 from pycs.database.util import commit_on_return
 
-class BaseModel(db.Model):
+class BaseModel(db.Model, SerializerMixin):
     __abstract__ = True
 
+    # setup of the SerializerMixin
+    date_format = '%s'  # Unixtimestamp (seconds)
+    datetime_format = '%d. %b. %Y %H:%M:%S'
+    time_format = '%H:%M'
+
+
     id = db.Column(db.Integer, primary_key=True)
 
+    serialize_only = ("id",)
+
+    def __repr__(self):
+        attrs = self.serialize()
+        content = ", ".join([f"{attr}={value}" for attr, value in attrs.items()])
+        return f"<{self.__class__.__name__}: {content}>"
+
+
+    def serialize(self) -> dict:
+        return self.to_dict()
+
 
-    def remove(self, commit: bool = True) -> None:
+    @commit_on_return
+    def delete(self) -> dict:
         """
-        remove this instance from the database
+        delete this instance from the database
 
-        :return:
+        :return: serialized self
         """
         db.session.delete(self)
+        dump = self.serialize()
+
+        return dump
+
+
+    # do an alias
+    remove = delete
 
-        if commit:
-            self.commit()
 
     @classmethod
-    def new(cls, commit=False, **kwargs):
+    def new(cls, commit: bool = True, **kwargs):
         obj = cls(**kwargs)
         db.session.add(obj)
 
         if commit:
             self.commit()
 
+    @classmethod
+    def get_or_create(cls, **kwargs) -> T.Tuple[BaseModel, bool]:
+
+        is_new = False
+
+        obj = cls.query.filter_by(**kwargs).one_or_none()
+
+        if obj is None:
+            obj = cls.new(commit=False, **kwargs)
+            is_new = True
+
+        return obj, is_new
+
+
+    @classmethod
+    def get_or_404(cls, obj_id: int) -> BaseModel:
+        obj = cls.query.get(obj_id)
+
+        if obj is None:
+            abort(404, f"{cls.__name__} with ID {obj_id} could not be found!")
+
+        return obj
+
+
     def commit(self):
         db.session.commit()
 
+    def flush(self):
+        db.session.flush()
+
+
 class NamedBaseModel(BaseModel):
     __abstract__ = True
 
     name = db.Column(db.String, nullable=False)
 
+    serialize_only = BaseModel.serialize_only + ("name",)
+
     @commit_on_return
     def set_name(self, name: str):
         self.name = name