Browse Source

removed Database usage in the remaining Views. Fixed tests

Dimitri Korsch 3 years ago
parent
commit
ad5124a2d3
33 changed files with 14870 additions and 375 deletions
  1. 1 1
      pycs/database/File.py
  2. 5 3
      pycs/database/Label.py
  3. 14 8
      pycs/database/Project.py
  4. 23 23
      pycs/frontend/WebServer.py
  5. 2 7
      pycs/frontend/endpoints/ListLabelProviders.py
  6. 2 7
      pycs/frontend/endpoints/ListModels.py
  7. 8 11
      pycs/frontend/endpoints/data/GetFile.py
  8. 2 6
      pycs/frontend/endpoints/data/GetPreviousAndNextFile.py
  9. 100 102
      pycs/frontend/endpoints/data/GetResizedFile.py
  10. 14 17
      pycs/frontend/endpoints/data/RemoveFile.py
  11. 5 8
      pycs/frontend/endpoints/data/UploadFile.py
  12. 7 9
      pycs/frontend/endpoints/labels/CreateLabel.py
  13. 5 13
      pycs/frontend/endpoints/labels/EditLabelParent.py
  14. 2 6
      pycs/frontend/endpoints/labels/ListLabels.py
  15. 5 5
      pycs/frontend/endpoints/labels/RemoveLabel.py
  16. 47 40
      pycs/frontend/endpoints/projects/CreateProject.py
  17. 1 1
      pycs/frontend/endpoints/projects/EditProjectDescription.py
  18. 1 1
      pycs/frontend/endpoints/projects/EditProjectName.py
  19. 1 1
      pycs/frontend/endpoints/projects/ExecuteExternalStorage.py
  20. 13 8
      pycs/frontend/endpoints/projects/ExecuteLabelProvider.py
  21. 2 6
      pycs/frontend/endpoints/projects/GetProjectModel.py
  22. 5 12
      pycs/frontend/endpoints/projects/ListCollections.py
  23. 17 19
      pycs/frontend/endpoints/projects/ListFiles.py
  24. 4 4
      pycs/frontend/endpoints/projects/RemoveProject.py
  25. 6 9
      pycs/frontend/endpoints/results/ConfirmResult.py
  26. 22 14
      pycs/frontend/endpoints/results/CreateResult.py
  27. 5 6
      pycs/frontend/endpoints/results/EditResultData.py
  28. 6 7
      pycs/frontend/endpoints/results/EditResultLabel.py
  29. 2 6
      pycs/frontend/endpoints/results/GetResults.py
  30. 5 5
      pycs/frontend/endpoints/results/RemoveResult.py
  31. 10 8
      pycs/frontend/endpoints/results/ResetResults.py
  32. 2 1
      pycs/frontend/notifications/NotificationManager.py
  33. 14526 1
      webui/package-lock.json

+ 1 - 1
pycs/database/File.py

@@ -140,7 +140,7 @@ class File(NamedBaseModel):
         return self.results.get(id)
 
 
-    def create_result(self, origin, result_type, label, data: T.Optional[dict] = None):
+    def create_result(self, origin, result_type, label, data: T.Optional[dict] = None, commit: bool = True):
         result = Result.new(commit=False,
                             file_id=self.id,
                             origin=origin,

+ 5 - 3
pycs/database/Label.py

@@ -10,7 +10,7 @@ def compare_children(start_label: Label, id: int):
 
     labels_to_check = [start_label]
 
-    while labels_to_check:
+    while id is not None and labels_to_check:
         label = labels_to_check.pop(0)
 
         if label.id == id:
@@ -62,7 +62,7 @@ class Label(NamedBaseModel):
         "reference",
     )
 
-    def set_parent(self, parent_id: int):
+    def set_parent(self, parent_id: int, commit: bool = True):
         """
         set this labels parent
 
@@ -73,4 +73,6 @@ class Label(NamedBaseModel):
             raise ValueError('Cyclic relationship detected!')
 
         self.parent_id = parent_id
-        self.commit()
+
+        if commit:
+            self.commit()

+ 14 - 8
pycs/database/Project.py

@@ -99,7 +99,7 @@ class Project(NamedBaseModel):
         return self.collections.filter_by(reference=reference).one_or_none()
 
     def create_label(self, name: str, reference: str = None,
-                     parent_id: int = None) -> T.Tuple[T.Optional[Label], bool]:
+                     parent_id: int = None, commit: bool = True) -> T.Tuple[T.Optional[Label], bool]:
         """
         create a label for this project. If there is already a label with the same reference
         in the database its name is updated.
@@ -112,10 +112,11 @@ class Project(NamedBaseModel):
 
         label, is_new = Label.get_or_create(project=self, reference=reference)
 
-        label.set_name(name)
-        label.set_parent(parent_id)
+        label.name = name
+        label.set_parent(parent_id, commit=False)
 
-        self.commit()
+        if commit:
+            self.commit()
 
         return label, is_new
 
@@ -124,19 +125,22 @@ class Project(NamedBaseModel):
                           name: str,
                           description: str,
                           position: int,
-                          autoselect: bool):
+                          autoselect: bool,
+                          commit: bool = True):
 
         collection, is_new = Collection.get_or_create(project=self, reference=reference)
         collection.name = name
         collection.description = description
         collection.position = position
         collection.autoselect = autoselect
-        self.commit()
+
+        if commit:
+            self.commit()
 
         return collection, is_new
 
     def add_file(self, uuid: str, file_type: str, name: str, extension: str, size: int,
-                 filename: str, frames: int = None, fps: float = None) -> T.Tuple[File, bool]:
+                 filename: str, frames: int = None, fps: float = None, commit: bool = True) -> T.Tuple[File, bool]:
         """
         add a file to this project
 
@@ -162,7 +166,9 @@ class Project(NamedBaseModel):
         file.frames = frames
         file.fps = fps
 
-        self.commit()
+        if commit:
+            self.commit()
+
         return file, is_new
 
 

+ 23 - 23
pycs/frontend/WebServer.py

@@ -189,73 +189,73 @@ class WebServer:
         # models
         self.app.add_url_rule(
             '/models',
-            view_func=ListModels.as_view('list_models', self.db)
+            view_func=ListModels.as_view('list_models')
         )
         self.app.add_url_rule(
             '/projects/<int:identifier>/model',
-            view_func=GetProjectModel.as_view('get_project_model', self.db)
+            view_func=GetProjectModel.as_view('get_project_model')
         )
 
         # labels
         self.app.add_url_rule(
             '/label_providers',
-            view_func=ListLabelProviders.as_view('label_providers', self.db)
+            view_func=ListLabelProviders.as_view('label_providers')
         )
         self.app.add_url_rule(
             '/projects/<int:identifier>/labels',
-            view_func=ListLabels.as_view('list_labels', self.db)
+            view_func=ListLabels.as_view('list_labels')
         )
         self.app.add_url_rule(
             '/projects/<int:identifier>/labels',
-            view_func=CreateLabel.as_view('create_label', self.db, self.nm)
+            view_func=CreateLabel.as_view('create_label', self.nm)
         )
         self.app.add_url_rule(
             '/projects/<int:project_id>/labels/<int:label_id>/remove',
-            view_func=RemoveLabel.as_view('remove_label', self.db, self.nm)
+            view_func=RemoveLabel.as_view('remove_label', self.nm)
         )
         self.app.add_url_rule(
             '/projects/<int:project_id>/labels/<int:label_id>/name',
-            view_func=EditLabelName.as_view('edit_label_name', self.db, self.nm)
+            view_func=EditLabelName.as_view('edit_label_name', self.nm)
         )
         self.app.add_url_rule(
             '/projects/<int:project_id>/labels/<int:label_id>/parent',
-            view_func=EditLabelParent.as_view('edit_label_parent', self.db, self.nm)
+            view_func=EditLabelParent.as_view('edit_label_parent', self.nm)
         )
 
         # collections
         self.app.add_url_rule(
             '/projects/<int:project_id>/collections',
-            view_func=ListCollections.as_view('list_collections', self.db)
+            view_func=ListCollections.as_view('list_collections')
         )
         self.app.add_url_rule(
             '/projects/<int:project_id>/data/<int:collection_id>/<int:start>/<int:length>',
-            view_func=ListFiles.as_view('list_collection_files', self.db)
+            view_func=ListFiles.as_view('list_collection_files')
         )
 
         # data
         self.app.add_url_rule(
             '/projects/<int:identifier>/data',
-            view_func=UploadFile.as_view('upload_file', self.db, self.nm)
+            view_func=UploadFile.as_view('upload_file', self.nm)
         )
         self.app.add_url_rule(
             '/projects/<int:project_id>/data/<int:start>/<int:length>',
-            view_func=ListFiles.as_view('list_files', self.db)
+            view_func=ListFiles.as_view('list_files')
         )
         self.app.add_url_rule(
             '/data/<int:identifier>/remove',
-            view_func=RemoveFile.as_view('remove_file', self.db, self.nm)
+            view_func=RemoveFile.as_view('remove_file', self.nm)
         )
         self.app.add_url_rule(
             '/data/<int:file_id>',
-            view_func=GetFile.as_view('get_file', self.db)
+            view_func=GetFile.as_view('get_file')
         )
         self.app.add_url_rule(
             '/data/<int:file_id>/<resolution>',
-            view_func=GetResizedFile.as_view('get_resized_file', self.db)
+            view_func=GetResizedFile.as_view('get_resized_file')
         )
         self.app.add_url_rule(
             '/data/<int:file_id>/previous_next',
-            view_func=GetPreviousAndNextFile.as_view('get_previous_and_next_file', self.db)
+            view_func=GetPreviousAndNextFile.as_view('get_previous_and_next_file')
         )
 
         # results
@@ -265,31 +265,31 @@ class WebServer:
         )
         self.app.add_url_rule(
             '/data/<int:file_id>/results',
-            view_func=GetResults.as_view('get_results', self.db)
+            view_func=GetResults.as_view('get_results')
         )
         self.app.add_url_rule(
             '/data/<int:file_id>/results',
-            view_func=CreateResult.as_view('create_result', self.db, self.nm)
+            view_func=CreateResult.as_view('create_result', self.nm)
         )
         self.app.add_url_rule(
             '/data/<int:file_id>/reset',
-            view_func=ResetResults.as_view('reset_results', self.db, self.nm)
+            view_func=ResetResults.as_view('reset_results', self.nm)
         )
         self.app.add_url_rule(
             '/results/<int:result_id>/remove',
-            view_func=RemoveResult.as_view('remove_result', self.db, self.nm)
+            view_func=RemoveResult.as_view('remove_result', self.nm)
         )
         self.app.add_url_rule(
             '/results/<int:result_id>/confirm',
-            view_func=ConfirmResult.as_view('confirm_result', self.db, self.nm)
+            view_func=ConfirmResult.as_view('confirm_result', self.nm)
         )
         self.app.add_url_rule(
             '/results/<int:result_id>/label',
-            view_func=EditResultLabel.as_view('edit_result_label', self.db, self.nm)
+            view_func=EditResultLabel.as_view('edit_result_label', self.nm)
         )
         self.app.add_url_rule(
             '/results/<int:result_id>/data',
-            view_func=EditResultData.as_view('edit_result_data', self.db, self.nm)
+            view_func=EditResultData.as_view('edit_result_data', self.nm)
         )
 
         # projects

+ 2 - 7
pycs/frontend/endpoints/ListLabelProviders.py

@@ -1,7 +1,7 @@
 from flask import jsonify
 from flask.views import View
 
-from pycs.database.Database import Database
+from pycs.database.LabelProvider import LabelProvider
 
 
 class ListLabelProviders(View):
@@ -11,10 +11,5 @@ class ListLabelProviders(View):
     # pylint: disable=arguments-differ
     methods = ['GET']
 
-    def __init__(self, db: Database):
-        # pylint: disable=invalid-name
-        self.db = db
-
     def dispatch_request(self):
-        label_providers = list(self.db.label_providers())
-        return jsonify(label_providers)
+        return jsonify(LabelProviders.query.all())

+ 2 - 7
pycs/frontend/endpoints/ListModels.py

@@ -1,7 +1,7 @@
 from flask import jsonify
 from flask.views import View
 
-from pycs.database.Database import Database
+from pycs.database.Model import Model
 
 
 class ListModels(View):
@@ -11,10 +11,5 @@ class ListModels(View):
     # pylint: disable=arguments-differ
     methods = ['GET']
 
-    def __init__(self, db: Database):
-        # pylint: disable=invalid-name
-        self.db = db
-
     def dispatch_request(self):
-        models = list(self.db.models())
-        return jsonify(models)
+        return jsonify(Model.query.all())

+ 8 - 11
pycs/frontend/endpoints/data/GetFile.py

@@ -1,9 +1,10 @@
-from os import path, getcwd
+import os
 
-from flask import abort, send_from_directory
+from flask import abort
+from flask import send_from_directory
 from flask.views import View
 
-from pycs.database.Database import Database
+from pycs.database.File import File
 
 
 class GetFile(View):
@@ -13,23 +14,19 @@ class GetFile(View):
     # pylint: disable=arguments-differ
     methods = ['GET']
 
-    def __init__(self, db: Database):
-        # pylint: disable=invalid-name
-        self.db = db
-
     def dispatch_request(self, file_id: int):
         # get file from database
-        file = self.db.file(file_id)
+        file = File.query.get(file_id)
 
         if file is None:
             return abort(404)
 
         # get absolute path
-        if path.isabs(file.path):
+        if os.path.isabs(file.path):
             abs_file_path = file.path
         else:
-            abs_file_path = path.join(getcwd(), file.path)
+            abs_file_path = os.path.join(os.getcwd(), file.path)
 
         # return data
-        file_directory, file_name = path.split(abs_file_path)
+        file_directory, file_name = os.path.split(abs_file_path)
         return send_from_directory(file_directory, file_name)

+ 2 - 6
pycs/frontend/endpoints/data/GetPreviousAndNextFile.py

@@ -1,7 +1,7 @@
 from flask import abort, jsonify
 from flask.views import View
 
-from pycs.database.Database import Database
+from pycs.database.File import File
 
 
 class GetPreviousAndNextFile(View):
@@ -11,13 +11,9 @@ class GetPreviousAndNextFile(View):
     # pylint: disable=arguments-differ
     methods = ['GET']
 
-    def __init__(self, db: Database):
-        # pylint: disable=invalid-name
-        self.db = db
-
     def dispatch_request(self, file_id: int):
         # get file from database
-        file = self.db.file(file_id)
+        file = File.query.get(file_id)
 
         if file is None:
             return abort(404)

+ 100 - 102
pycs/frontend/endpoints/data/GetResizedFile.py

@@ -1,13 +1,15 @@
+import cv2
+import os
 import re
-from os import path, getcwd
 
-import cv2
 from PIL import Image
 from eventlet import tpool
-from flask import abort, send_from_directory
+from flask import abort
+from flask import send_from_directory
 from flask.views import View
 
-from pycs.database.Database import Database
+from pycs.database.File import File
+from pycs.database.Project import Project
 
 
 class GetResizedFile(View):
@@ -17,17 +19,13 @@ class GetResizedFile(View):
     # pylint: disable=arguments-differ
     methods = ['GET']
 
-    def __init__(self, db: Database):
-        # pylint: disable=invalid-name
-        self.db = db
-
     def dispatch_request(self, file_id: int, resolution: str):
         # get file from database
-        file = self.db.file(file_id)
+        file = File.query.get(file_id)
         if file is None:
             return abort(404, "file object not found")
 
-        if not path.exists(file.path):
+        if not os.path.exists(file.path):
             return abort(404, "image not found!")
 
         project = file.project
@@ -38,103 +36,103 @@ class GetResizedFile(View):
         max_height = int(resolution[1]) if len(resolution) > 1 else 2 ** 24
 
         # send data
-        file_directory, file_name = tpool.execute(self.resize_file,
-                                                  project, file, max_width, max_height)
+        file_directory, file_name = tpool.execute(resize_file,
+                                                  project.id, file.id, max_width, max_height)
         return send_from_directory(file_directory, file_name)
 
-    @staticmethod
-    def resize_file(project, file, max_width, max_height):
-        """
-        If file type equals video this function extracts a thumbnail first. It calls resize_image
-        to resize and returns the resized files directory and name.
-
-        :param project: associated project
-        :param file: file object
-        :param max_width: maximum image or thumbnail width
-        :param max_height: maximum image or thumbnail height
-        :return: resized file directory, resized file name
-        """
-        # get absolute path
-        if path.isabs(file.path):
-            abs_file_path = file.path
-        else:
-            abs_file_path = path.join(getcwd(), file.path)
-
-        # extract video thumbnail
-        if file.type == 'video':
-            abs_target_path = path.join(getcwd(), project.root_folder, 'temp', f'{file.uuid}.jpg')
-            GetResizedFile.create_thumbnail(abs_file_path, abs_target_path)
-
-            abs_file_path = abs_target_path
-
-        # resize image file
-        abs_target_path = path.join(getcwd(), project.root_folder,
-                                    'temp', f'{file.uuid}_{max_width}_{max_height}.jpg')
-        result = GetResizedFile.resize_image(abs_file_path, abs_target_path, max_width, max_height)
-
-        # return path
-        if result is not None:
-            return path.split(abs_target_path)
-
-        return path.split(abs_file_path)
-
-    @staticmethod
-    def resize_image(file_path, target_path, max_width, max_height):
-        """
-        resize an image so width < max_width and height < max_height
-
-        :param file_path: path to source file
-        :param target_path: path to target file
-        :param max_width: maximum image width
-        :param max_height: maximum image height
-        :return:
-        """
-        # return if file exists
-        if path.exists(target_path):
-            return True
-
-        # load full size image
-        image = Image.open(file_path)
-        img_width, img_height = image.size
-
-        # abort if file is smaller than desired
-        if img_width < max_width and img_height < max_height:
-            return None
-
-        # calculate target size
-        target_width = int(max_width)
-        target_height = int(max_width * img_height / img_width)
-
-        if target_height > max_height:
-            target_height = int(max_height)
-            target_width = int(max_height * img_width / img_height)
-
-        # resize image
-        resized_image = image.resize((target_width, target_height))
-
-        # save to file
-        resized_image.save(target_path, quality=80)
+def resize_file(project_id, file_id, max_width, max_height):
+    """
+    If file type equals video this function extracts a thumbnail first. It calls resize_image
+    to resize and returns the resized files directory and name.
+
+    :param project: associated project
+    :param file: file object
+    :param max_width: maximum image or thumbnail width
+    :param max_height: maximum image or thumbnail height
+    :return: resized file directory, resized file name
+    """
+    file = File.query.get(file_id)
+    project = Project.query.get(project_id)
+
+    # get absolute path
+    if os.path.isabs(file.path):
+        abs_file_path = file.path
+    else:
+        abs_file_path = os.path.join(getcwd(), file.path)
+
+    # extract video thumbnail
+    if file.type == 'video':
+        abs_target_path = os.path.join(os.getcwd(), project.root_folder, 'temp', f'{file.uuid}.jpg')
+        create_thumbnail(abs_file_path, abs_target_path)
+
+        abs_file_path = abs_target_path
+
+    # resize image file
+    abs_target_path = os.path.join(os.getcwd(), project.root_folder,
+                                'temp', f'{file.uuid}_{max_width}_{max_height}.jpg')
+    result = resize_image(abs_file_path, abs_target_path, max_width, max_height)
+
+    # return path
+    if result is not None:
+        return os.path.split(abs_target_path)
+
+    return os.path.split(abs_file_path)
+
+def resize_image(file_path, target_path, max_width, max_height):
+    """
+    resize an image so width < max_width and height < max_height
+
+    :param file_path: path to source file
+    :param target_path: path to target file
+    :param max_width: maximum image width
+    :param max_height: maximum image height
+    :return:
+    """
+    # return if file exists
+    if os.path.exists(target_path):
         return True
 
-    @staticmethod
-    def create_thumbnail(file_path, target_path):
-        """
-        extract a thumbnail from a video
+    # load full size image
+    image = Image.open(file_path)
+    img_width, img_height = image.size
+
+    # abort if file is smaller than desired
+    if img_width < max_width and img_height < max_height:
+        return None
+
+    # calculate target size
+    target_width = int(max_width)
+    target_height = int(max_width * img_height / img_width)
+
+    if target_height > max_height:
+        target_height = int(max_height)
+        target_width = int(max_height * img_width / img_height)
 
-        :param file_path: path to source file
-        :param target_path: path to target file
-        :return:
-        """
-        # return if file exists
-        if path.exists(target_path):
-            return
+    # resize image
+    resized_image = image.resize((target_width, target_height))
+
+    # save to file
+    resized_image.save(target_path, quality=80)
+    return True
+
+def create_thumbnail(file_path, target_path):
+    """
+    extract a thumbnail from a video
+
+    :param file_path: path to source file
+    :param target_path: path to target file
+    :return:
+    """
+    # return if file exists
+    if os.path.exists(target_path):
+        return
 
-        # load video
-        video = cv2.VideoCapture(file_path)
+    # load video
+    video = cv2.VideoCapture(file_path)
 
-        # create thumbnail
-        _, image = video.read()
-        cv2.imwrite(target_path, image)
+    # create thumbnail
+    _, image = video.read()
+    cv2.imwrite(target_path, image)
 
-        # close video file
-        video.release()
+    # close video file
+    video.release()

+ 14 - 17
pycs/frontend/endpoints/data/RemoveFile.py

@@ -3,7 +3,7 @@ from os import remove
 from flask import make_response, request, abort
 from flask.views import View
 
-from pycs.database.Database import Database
+from pycs.database.File import File
 from pycs.frontend.notifications.NotificationManager import NotificationManager
 
 
@@ -14,34 +14,31 @@ class RemoveFile(View):
     # pylint: disable=arguments-differ
     methods = ['POST']
 
-    def __init__(self, db: Database, nm: NotificationManager):
+    def __init__(self, nm: NotificationManager):
         # pylint: disable=invalid-name
-        self.db = db
         self.nm = nm
 
     def dispatch_request(self, identifier):
         # extract request data
         data = request.get_json(force=True)
 
-        if 'remove' not in data or data['remove'] is not True:
+        if not data.get('remove', False):
             return abort(400)
 
-        # start transaction
-        with self.db:
-            # find file
-            file = self.db.file(identifier)
-            if file is None:
-                return abort(400)
+        # find file
+        file = File.query.get(identifier)
+        if file is None:
+            return abort(400)
 
-            # check if project uses an external data directory
-            if file.project.external_data:
-                return abort(400)
+        # check if project uses an external data directory
+        if file.project.external_data:
+            return abort(400)
 
-            # remove file from database
-            file.remove()
+        # remove file from database
+        file.remove()
 
-            # remove file from folder
-            remove(file.path)
+        # remove file from folder
+        remove(file.path)
 
             # TODO remove temp files
 

+ 5 - 8
pycs/frontend/endpoints/data/UploadFile.py

@@ -6,7 +6,7 @@ from flask import make_response, request, abort
 from flask.views import View
 from werkzeug import formparser
 
-from pycs.database.Database import Database
+from pycs.database.Project import Project
 from pycs.frontend.notifications.NotificationManager import NotificationManager
 from pycs.util.FileParser import file_info
 
@@ -18,9 +18,8 @@ class UploadFile(View):
     # pylint: disable=arguments-differ
     methods = ['POST']
 
-    def __init__(self, db: Database, nm: NotificationManager):
+    def __init__(self, nm: NotificationManager):
         # pylint: disable=invalid-name
-        self.db = db
         self.nm = nm
 
         self.data_folder = None
@@ -31,7 +30,7 @@ class UploadFile(View):
 
     def dispatch_request(self, identifier):
         # find project
-        project = self.db.project(identifier)
+        project = Project.query.get(identifier)
 
         if project is None:
             return abort(404, "Project not found")
@@ -59,10 +58,8 @@ class UploadFile(View):
         except ValueError as e:
             return abort(400, str(e))
 
-        # add to project files
-        with self.db:
-            file, _ = project.add_file(self.file_id, ftype, self.file_name, self.file_extension,
-                                       self.file_size, self.file_id, frames, fps)
+        file, _ = project.add_file(self.file_id, ftype, self.file_name, self.file_extension,
+                                   self.file_size, self.file_id, frames, fps)
 
         # send update
         self.nm.create_file(file.id)

+ 7 - 9
pycs/frontend/endpoints/labels/CreateLabel.py

@@ -1,7 +1,8 @@
 from flask import request, abort, make_response
 from flask.views import View
 
-from pycs.database.Database import Database
+from pycs import db
+from pycs.database.Project import Project
 from pycs.frontend.notifications.NotificationManager import NotificationManager
 
 
@@ -12,9 +13,8 @@ class CreateLabel(View):
     # pylint: disable=arguments-differ
     methods = ['POST']
 
-    def __init__(self, db: Database, nm: NotificationManager):
+    def __init__(self, nm: NotificationManager):
         # pylint: disable=invalid-name
-        self.db = db
         self.nm = nm
 
     def dispatch_request(self, identifier):
@@ -25,17 +25,15 @@ class CreateLabel(View):
             abort(400)
 
         name = data['name']
-        parent = data['parent'] if 'parent' in data else None
+        parent = data.get('parent')
 
         # find project
-        project = self.db.project(identifier)
+        project = Project.query.get(identifier)
         if project is None:
             abort(404)
 
-        # start transaction
-        with self.db:
-            # insert label
-            label, _ = project.create_label(name, parent_id=parent)
+        # insert label
+        label, _ = project.create_label(name, parent_id=parent)
 
         # send notification
         self.nm.create_label(label.id)

+ 5 - 13
pycs/frontend/endpoints/labels/EditLabelParent.py

@@ -1,7 +1,7 @@
 from flask import request, abort, make_response
 from flask.views import View
 
-from pycs.database.Database import Database
+from pycs.database.Label import Label
 from pycs.frontend.notifications.NotificationManager import NotificationManager
 
 
@@ -12,9 +12,8 @@ class EditLabelParent(View):
     # pylint: disable=arguments-differ
     methods = ['POST']
 
-    def __init__(self, db: Database, nm: NotificationManager):
+    def __init__(self, nm: NotificationManager):
         # pylint: disable=invalid-name
-        self.db = db
         self.nm = nm
 
     def dispatch_request(self, project_id: int, label_id: int):
@@ -24,20 +23,13 @@ class EditLabelParent(View):
         if 'parent' not in data:
             abort(400)
 
-        # find project
-        project = self.db.project(project_id)
-        if project is None:
-            abort(404)
-
         # find label
-        label = project.label(label_id)
+        label = Label.query.filter_by(id=label_id, project_id=project_id).one_or_none()
         if label is None:
             abort(404)
 
-        # start transaction
-        with self.db:
-            # change parent
-            label.set_parent(data['parent'])
+        # change parent
+        label.set_parent(data['parent'])
 
         # send notification
         self.nm.edit_label(label.id)

+ 2 - 6
pycs/frontend/endpoints/labels/ListLabels.py

@@ -1,7 +1,7 @@
 from flask import abort, jsonify
 from flask.views import View
 
-from pycs.database.Database import Database
+from pycs.database.Project import Project
 
 
 class ListLabels(View):
@@ -11,13 +11,9 @@ class ListLabels(View):
     # pylint: disable=arguments-differ
     methods = ['GET']
 
-    def __init__(self, db: Database):
-        # pylint: disable=invalid-name
-        self.db = db
-
     def dispatch_request(self, identifier):
         # find project
-        project = self.db.project(identifier)
+        project = Project.query.get(identifier)
         if project is None:
             abort(404)
 

+ 5 - 5
pycs/frontend/endpoints/labels/RemoveLabel.py

@@ -1,7 +1,8 @@
 from flask import request, abort, make_response
 from flask.views import View
 
-from pycs.database.Database import Database
+from pycs import db
+from pycs.database.Project import Project
 from pycs.frontend.notifications.NotificationManager import NotificationManager
 
 
@@ -12,9 +13,8 @@ class RemoveLabel(View):
     # pylint: disable=arguments-differ
     methods = ['POST']
 
-    def __init__(self, db: Database, nm: NotificationManager):
+    def __init__(self, nm: NotificationManager):
         # pylint: disable=invalid-name
-        self.db = db
         self.nm = nm
 
     def dispatch_request(self, project_id: int, label_id: int):
@@ -25,7 +25,7 @@ class RemoveLabel(View):
             abort(400)
 
         # find project
-        project = self.db.project(project_id)
+        project = Project.query.get(project_id)
         if project is None:
             abort(404)
 
@@ -38,7 +38,7 @@ class RemoveLabel(View):
         children = label.children
 
         # start transaction
-        with self.db:
+        with db.session.begin_nested():
             # remove children's parent entry
             for child in children:
                 child.set_parent(None)

+ 47 - 40
pycs/frontend/endpoints/projects/CreateProject.py

@@ -52,51 +52,55 @@ class CreateProject(View):
         data_folder = data['external']
         external_data = data_folder is not None
 
-        # start transaction
-        with db.session.start():
-            # find model
-            model = Model.query.get(model_id)
+        # find model
+        model = Model.query.get(model_id)
 
-            if model is None:
-                return abort(404, "Model not found")
+        if model is None:
+            return abort(404, "Model not found")
 
-            # find label provider
-            if label_provider_id is None:
-                label_provider = None
-            else:
-                label_provider = LabelProvider.query.get(label_provider_id)
+        # find label provider
+        if label_provider_id is None:
+            label_provider = None
+        else:
+            label_provider = LabelProvider.query.get(label_provider_id)
 
-                if label_provider is None:
-                    return abort(404, "Label provider not found")
+            if label_provider is None:
+                return abort(404, "Label provider not found")
 
-            # create project folder
-            project_folder = Path(self.project_folder, str(uuid1()))
-            project_folder.mkdir(parents=True)
+        # create project folder
+        project_folder = Path(self.project_folder, str(uuid1()))
+        project_folder.mkdir(parents=True)
 
-            temp_folder = project_folder / 'temp'
-            temp_folder.mkdir()
+        temp_folder = project_folder / 'temp'
+        temp_folder.mkdir()
 
-            # check project data directory
-            if external_data:
-                # check if exists
-                if not path.exists(data_folder):
-                    return abort(400, "Data folder does not exist!")
+        # check project data directory
+        if external_data:
+            # check if exists
+            if not path.exists(data_folder):
+                return abort(400, "Data folder does not exist!")
 
-            else:
-                data_folder = project_folder / 'data'
-                data_folder.mkdir()
+        else:
+            data_folder = project_folder / 'data'
+            data_folder.mkdir()
 
 
-            # copy model to project folder
-            model_folder = project_folder / 'model'
-            copytree(model.root_folder, str(model_folder))
+        # copy model to project folder
+        model_folder = project_folder / 'model'
+        copytree(model.root_folder, str(model_folder))
 
-            model, _ = model.copy_to(f'{model.name} ({name})', str(model_folder))
+        model, _ = model.copy_to(f'{model.name} ({name})', str(model_folder))
 
-            # create entry in database
-            project = Project.new(name, description, model, label_provider,
-                                  str(project_folder), external_data,
-                                  str(data_folder))
+        # create entry in database
+        project = Project.new(
+            name=name,
+            description=description,
+            model_id=model.id,
+            label_provider_id=label_provider.id,
+            root_folder=str(project_folder),
+            external_data=external_data,
+            data_folder=str(data_folder)
+        )
 
         # execute label provider and add labels to project
         if label_provider is not None:
@@ -112,13 +116,16 @@ class CreateProject(View):
         project_id = project.id
         def add_collections_to_project(provided_collections):
             project = Project.query.get(project_id)
-            with db.session.start():
+            with db.session.begin_nested():
                 for position, collection in enumerate(provided_collections):
-                    project.create_collection(collection['reference'],
-                                              collection['name'],
-                                              collection['description'],
-                                              position + 1,
-                                              collection['autoselect'])
+                    project.create_collection(
+                        collection['reference'],
+                        collection['name'],
+                        collection['description'],
+                        position + 1,
+                        collection['autoselect'],
+                        commit=False,
+                    )
 
         self.jobs.run(project,
                       'Media Collections',

+ 1 - 1
pycs/frontend/endpoints/projects/EditProjectDescription.py

@@ -24,7 +24,7 @@ class EditProjectDescription(View):
         if 'description' not in data or not data['description']:
             return abort(400)
 
-        with db.session.start()
+        with db.session.begin_nested():
             # find project
             project = Project.query.get(identifier)
             if project is None:

+ 1 - 1
pycs/frontend/endpoints/projects/EditProjectName.py

@@ -25,7 +25,7 @@ class EditProjectName(View):
             return abort(400)
 
         # start transaction
-        with db.session.start():
+        with db.session.begin_nested():
             # find project
             project = Project.query.get(identifier)
             if project is None:

+ 1 - 1
pycs/frontend/endpoints/projects/ExecuteExternalStorage.py

@@ -100,7 +100,7 @@ class ExecuteExternalStorage(View):
 
         # progress inserts elements into the database and fires events
         def progress(elements, current, length):
-            with db.session.start():
+            with db.session.begin_nested():
                 project = Project.query.get(project_id)
                 for ftype, file_name, file_extension, file_size, frames, fps in elements:
                     file, is_new = project.add_file(str(uuid1()), ftype, file_name,

+ 13 - 8
pycs/frontend/endpoints/projects/ExecuteLabelProvider.py

@@ -76,14 +76,19 @@ class ExecuteLabelProvider(View):
         # result adds the received labels to the database and fires events
         def result(provided_labels):
             project = Project.query.get(project_id)
-            with db.session.start():
-                for label in provided_labels:
-                    created_label, is_new = project.create_label(**label)
-
-                    if is_new:
-                        nm.create_label(created_label.id)
-                    else:
-                        nm.edit_label(created_label.id)
+            for label in provided_labels:
+                success = False
+                with db.session.begin_nested():
+                    created_label, is_new = project.create_label(commit=False, **label)
+                    success = True
+
+                if not success:
+                    continue
+
+                if is_new:
+                    nm.create_label(created_label.id)
+                else:
+                    nm.edit_label(created_label.id)
 
         # run job with given functions
         jobs.run(project,

+ 2 - 6
pycs/frontend/endpoints/projects/GetProjectModel.py

@@ -1,7 +1,7 @@
 from flask import abort, jsonify
 from flask.views import View
 
-from pycs.database.Database import Database
+from pycs.database.Project import Project
 
 
 class GetProjectModel(View):
@@ -11,13 +11,9 @@ class GetProjectModel(View):
     # pylint: disable=arguments-differ
     methods = ['GET']
 
-    def __init__(self, db: Database):
-        # pylint: disable=invalid-name
-        self.db = db
-
     def dispatch_request(self, identifier):
         # find project
-        project = self.db.project(identifier)
+        project = Project.query.get(identifier)
         if project is None:
             abort(404)
 

+ 5 - 12
pycs/frontend/endpoints/projects/ListCollections.py

@@ -1,7 +1,8 @@
-from flask import abort, jsonify
+from flask import abort
+from flask import jsonify
 from flask.views import View
 
-from pycs.database.Database import Database
+from pycs.database.Collection import Collection
 
 
 class ListCollections(View):
@@ -11,19 +12,10 @@ class ListCollections(View):
     # pylint: disable=arguments-differ
     methods = ['GET']
 
-    def __init__(self, db: Database):
-        # pylint: disable=invalid-name
-        self.db = db
-
     def dispatch_request(self, project_id: int):
-        # find project
-        project = self.db.project(project_id)
-
-        if project is None:
-            return abort(404)
 
         # get collection list
-        collections = project.collections.all()
+        collections = Collections.query.filter_by(project_id=project_id).all()
 
         # disable autoselect if there are no elements in the collection
         found = False
@@ -32,6 +24,7 @@ class ListCollections(View):
             if collection.autoselect:
                 if found:
                     collection.autoselect = False
+
                 elif collection.count_files() == 0:
                     collection.autoselect = False
                     found = True

+ 17 - 19
pycs/frontend/endpoints/projects/ListFiles.py

@@ -1,7 +1,8 @@
-from flask import abort, jsonify
+from flask import abort
+from flask import jsonify
 from flask.views import View
 
-from pycs.database.Database import Database
+from pycs.database.Project import Project
 
 
 class ListFiles(View):
@@ -11,32 +12,29 @@ class ListFiles(View):
     # pylint: disable=arguments-differ
     methods = ['GET']
 
-    def __init__(self, db: Database):
-        # pylint: disable=invalid-name
-        self.db = db
-
     def dispatch_request(self, project_id: int, start: int, length: int, collection_id: int = None):
         # find project
-        project = self.db.project(project_id)
+        project = Project.query.get(project_id)
         if project is None:
             return abort(404)
 
         # get count and files
-        if collection_id is not None:
-            if collection_id == 0:
-                count = project.count_files_without_collection()
-                files = list(project.files_without_collection(start, length))
-            else:
-                collection = project.collection(collection_id)
-                if collection is None:
-                    return abort(404)
-
-                count = collection.count_files()
-                files = list(collection.files(start, length))
-        else:
+        if collection_id is None:
             count = project.count_files()
             files = list(project.get_files(start, length))
 
+        elif collection_id == 0:
+            count = project.count_files_without_collection()
+            files = list(project.files_without_collection(start, length))
+
+        else:
+            collection = project.collection(collection_id)
+            if collection is None:
+                return abort(404)
+
+            count = collection.count_files()
+            files = list(collection.files(start, length))
+
         # return files
         return jsonify({
             'count': count,

+ 4 - 4
pycs/frontend/endpoints/projects/RemoveProject.py

@@ -24,14 +24,14 @@ class RemoveProject(View):
         data = request.get_json(force=True)
 
         if not data.get('remove', False):
-            abort(400)
+            return abort(400)
 
         # start transaction
-        with db.session.start():
+        with db.session.begin_nested():
             # find project
-            project = Project.query.id(identifier)
+            project = Project.query.get(identifier)
             if project is None:
-                abort(404, "Project not found")
+                return abort(404, "Project not found")
 
             # remove model from database
             model = project.model

+ 6 - 9
pycs/frontend/endpoints/results/ConfirmResult.py

@@ -1,7 +1,8 @@
 from flask import make_response, request, abort
 from flask.views import View
 
-from pycs.database.Database import Database
+from pycs import db
+from pycs.database.Result import Result
 from pycs.frontend.notifications.NotificationManager import NotificationManager
 
 
@@ -12,26 +13,22 @@ class ConfirmResult(View):
     # pylint: disable=arguments-differ
     methods = ['POST']
 
-    def __init__(self, db: Database, nm: NotificationManager):
+    def __init__(self, nm: NotificationManager):
         # pylint: disable=invalid-name
-        self.db = db
         self.nm = nm
 
     def dispatch_request(self, result_id: int):
         # extract request data
         data = request.get_json(force=True)
 
-        if 'confirm' not in data or data['confirm'] is not True:
+        if not data.get('confirm', False):
             return abort(400)
 
         # find result
-        result = self.db.result(result_id)
+        result = Result.query.get(result_id)
         if result is None:
             return abort(404)
 
-        # start transaction
-        with self.db:
-            result.set_origin('user')
-
+        result.set_origin('user')
         self.nm.edit_result(result.id)
         return make_response()

+ 22 - 14
pycs/frontend/endpoints/results/CreateResult.py

@@ -1,7 +1,8 @@
 from flask import request, abort, jsonify
 from flask.views import View
 
-from pycs.database.Database import Database
+from pycs import db
+from pycs.database.File import File
 from pycs.frontend.notifications.NotificationManager import NotificationManager
 
 
@@ -12,51 +13,58 @@ class CreateResult(View):
     # pylint: disable=arguments-differ
     methods = ['POST']
 
-    def __init__(self, db: Database, nm: NotificationManager):
+    def __init__(self, nm: NotificationManager):
         # pylint: disable=invalid-name
-        self.db = db
         self.nm = nm
 
     def dispatch_request(self, file_id: int):
         # extract request data
         request_data = request.get_json(force=True)
 
-        if 'type' not in request_data:
-            return abort(400)
-        if request_data['type'] not in ['labeled-image', 'bounding-box']:
+        if request_data.get('type') not in ['labeled-image', 'bounding-box']:
             return abort(400)
 
         rtype = request_data['type']
 
         if 'label' in request_data and request_data['label']:
             label = request_data['label']
+
         elif request_data['type'] == 'labeled-image':
             return abort(400, "label missing for the labeled-image annotation")
+
         else:
             label = None
 
-        if 'data' in request_data and request_data['data']:
+        if request_data.get('data'):
             data = request_data['data']
+
         elif request_data['type'] == 'bounding-box':
             return abort(400, "data missing for the bounding box annotation")
+
         else:
             data = {}
 
         # find file
-        file = self.db.file(file_id)
+        file = File.query.get(file_id)
         if file is None:
             return abort(404)
 
+
+        removed = []
         # start transaction
-        with self.db:
+        with db.session.begin_nested():
             # find full-image labels and remove them
             for result in file.results.all():
                 if result.type == 'labeled-image':
-                    result.remove()
-                    self.nm.remove_result(result.serialize())
+                    removed.append(result.serialize())
+                    result.remove(commit=True)
 
             # insert into database
-            result = file.create_result('user', rtype, label, data)
-            self.nm.create_result(result.id)
+            new_result = file.create_result('user', rtype, label, data,
+                commit=False)
+
+        for result in removed:
+            self.nm.remove_result(result.serialize())
 
-        return jsonify(result)
+        self.nm.create_result(new_result.id)
+        return jsonify(new_result)

+ 5 - 6
pycs/frontend/endpoints/results/EditResultData.py

@@ -1,7 +1,7 @@
 from flask import make_response, abort
 from flask.views import View, request
 
-from pycs.database.Database import Database
+from pycs.database.Result import Result
 from pycs.frontend.notifications.NotificationManager import NotificationManager
 
 
@@ -12,9 +12,8 @@ class EditResultData(View):
     # pylint: disable=arguments-differ
     methods = ['POST']
 
-    def __init__(self, db: Database, nm: NotificationManager):
+    def __init__(self, nm: NotificationManager):
         # pylint: disable=invalid-name
-        self.db = db
         self.nm = nm
 
     def dispatch_request(self, result_id: int):
@@ -25,14 +24,14 @@ class EditResultData(View):
             return abort(400)
 
         # find result
-        result = self.db.result(result_id)
+        result = Result.query.get(result_id)
         if result is None:
             return abort(404)
 
         # start transaction and set label
-        with self.db:
+        with db.session.get():
             result.data = data['data']
-            result.set_origin('user')
+            result.origin = 'user'
 
         self.nm.edit_result(result.id)
         return make_response()

+ 6 - 7
pycs/frontend/endpoints/results/EditResultLabel.py

@@ -1,7 +1,7 @@
 from flask import make_response, abort
 from flask.views import View, request
 
-from pycs.database.Database import Database
+from pycs.database.Result import Result
 from pycs.frontend.notifications.NotificationManager import NotificationManager
 
 
@@ -12,9 +12,8 @@ class EditResultLabel(View):
     # pylint: disable=arguments-differ
     methods = ['POST']
 
-    def __init__(self, db: Database, nm: NotificationManager):
+    def __init__(self, nm: NotificationManager):
         # pylint: disable=invalid-name
-        self.db = db
         self.nm = nm
 
     def dispatch_request(self, result_id: int):
@@ -25,7 +24,7 @@ class EditResultLabel(View):
             return abort(400)
 
         # find result
-        result = self.db.result(result_id)
+        result = Result.query.get(result_id)
         if result is None:
             return abort(404)
 
@@ -34,9 +33,9 @@ class EditResultLabel(View):
             return abort(400)
 
         # start transaction and set label
-        with self.db:
-            result.set_label(data['label'])
-            result.set_origin('user')
+        with db.session.begin_nested():
+            result.label_id = int(data['label'])
+            result.origin = 'user'
 
         self.nm.edit_result(result.id)
         return make_response()

+ 2 - 6
pycs/frontend/endpoints/results/GetResults.py

@@ -1,7 +1,7 @@
 from flask import abort, jsonify
 from flask.views import View
 
-from pycs.database.Database import Database
+from pycs.database.File import File
 
 
 class GetResults(View):
@@ -11,13 +11,9 @@ class GetResults(View):
     # pylint: disable=arguments-differ
     methods = ['GET']
 
-    def __init__(self, db: Database):
-        # pylint: disable=invalid-name
-        self.db = db
-
     def dispatch_request(self, file_id: int):
         # get file from database
-        file = self.db.file(file_id)
+        file = File.query.get(file_id)
         if file is None:
             return abort(404)
 

+ 5 - 5
pycs/frontend/endpoints/results/RemoveResult.py

@@ -1,7 +1,8 @@
 from flask import make_response, request, abort
 from flask.views import View
 
-from pycs.database.Database import Database
+from pycs import db
+from pycs.database.Result import Result
 from pycs.frontend.notifications.NotificationManager import NotificationManager
 
 
@@ -12,9 +13,8 @@ class RemoveResult(View):
     # pylint: disable=arguments-differ
     methods = ['POST']
 
-    def __init__(self, db: Database, nm: NotificationManager):
+    def __init__(self, nm: NotificationManager):
         # pylint: disable=invalid-name
-        self.db = db
         self.nm = nm
 
     def dispatch_request(self, result_id: int):
@@ -25,12 +25,12 @@ class RemoveResult(View):
             abort(400)
 
         # find result
-        result = self.db.result(result_id)
+        result = Result.query.get(result_id)
         if result is None:
             return abort(404)
 
         # start transaction
-        with self.db:
+        with db.session.begin_nested():
             result.remove()
 
         self.nm.remove_result(result.serialize())

+ 10 - 8
pycs/frontend/endpoints/results/ResetResults.py

@@ -1,7 +1,8 @@
 from flask import make_response, abort
 from flask.views import View, request
 
-from pycs.database.Database import Database
+from pycs import db
+from pycs.database.File import File
 from pycs.frontend.notifications.NotificationManager import NotificationManager
 
 
@@ -12,20 +13,19 @@ class ResetResults(View):
     # pylint: disable=arguments-differ
     methods = ['POST']
 
-    def __init__(self, db: Database, nm: NotificationManager):
+    def __init__(self, nm: NotificationManager):
         # pylint: disable=invalid-name
-        self.db = db
         self.nm = nm
 
     def dispatch_request(self, file_id: int):
         # extract request data
         data = request.get_json(force=True)
 
-        if 'reset' not in data or data['reset'] is not True:
+        if not data.get('reset', False):
             abort(400)
 
         # find file
-        file = self.db.file(file_id)
+        file = File.query.get(file_id)
         if file is None:
             return abort(404)
 
@@ -33,11 +33,13 @@ class ResetResults(View):
         results = file.results.all()
 
         # start transaction
-        with self.db:
+        removed = []
+        with db.session.begin_nested():
             for result in results:
+                removed.append(result.serialize())
                 result.remove()
 
-        for result in results:
-            self.nm.remove_result(result.serialize())
+        for result in removed:
+            self.nm.remove_result(result)
 
         return make_response()

+ 2 - 1
pycs/frontend/notifications/NotificationManager.py

@@ -21,7 +21,8 @@ class NotificationManager:
 
     def __emit(self, name, obj_id, cls=None):
         if cls is not None:
-            assert isinstance(obj_id, int), "Object ID must be an integer!"
+            assert isinstance(obj_id, int), \
+                f"{cls.__name__} ID must be an integer, but was {type(obj_id)} ({obj_id=})!"
             obj = cls.query.get(obj_id)
 
         else:

File diff suppressed because it is too large
+ 14526 - 1
webui/package-lock.json


Some files were not shown because too many files changed in this diff