6
0
Bläddra i källkod

updated object deletion (especially Project)

Dimitri Korsch 4 år sedan
förälder
incheckning
ed5dc9ccdf

+ 1 - 1
labels/flat_moth_label_provider/moth_labels.py

@@ -20,7 +20,7 @@ class FlatMothLabels(LabelProvider):
         labels = []
         for key, entries in self.mapping.items():
             display_name = f"{entries['genus']} {entries['species']} ({entries['kr']})"
-            print(key, display_name)
+            #print(key, display_name)
             label = self.create_label(key, display_name)
             labels.append(label)
 

+ 3 - 1
pycs/database/Collection.py

@@ -8,7 +8,9 @@ class Collection(NamedBaseModel):
 
     # table columns
     project_id = db.Column(
-        db.Integer, db.ForeignKey("project.id", ondelete="CASCADE"), nullable=False)
+        db.Integer,
+        db.ForeignKey("project.id", ondelete="CASCADE"),
+        nullable=False)
 
     reference = db.Column(
         db.String, nullable=False)

+ 6 - 5
pycs/database/Database.py

@@ -33,11 +33,11 @@ class Database:
             LabelProvider.discover("labels/")
 
     def __enter__(self):
-        app.logger.warning("Database.__enter__(): REMOVE ME!")
-        return self
+        # app.logger.warning("Database.__enter__(): REMOVE ME!")
+        return db
 
     def __exit__(self, exc_type, exc_val, exc_tb):
-        app.logger.warning("Database.__exit__(): REMOVE ME!")
+        # app.logger.warning("Database.__exit__(): REMOVE ME!")
 
         if exc_type is None:
             db.session.commit()
@@ -139,7 +139,8 @@ class Database:
                        label_provider: Optional[LabelProvider],
                        root_folder: str,
                        external_data: bool,
-                       data_folder: str):
+                       data_folder: str,
+                       commit: bool = True):
         """
         insert a project into the database
 
@@ -154,7 +155,7 @@ class Database:
         """
         # prepare some values
 
-        return Project.new(commit=True,
+        return Project.new(commit=commit,
             name=name,
             description=description,
             model=model,

+ 1 - 1
pycs/database/File.py

@@ -49,7 +49,7 @@ class File(NamedBaseModel):
 
 
     # relationships to other models
-    results = db.relationship("Result", backref="file", lazy="dynamic")
+    results = db.relationship("Result", backref="file", lazy="dynamic", passive_deletes=True)
 
     serialize_only = (
         "id", "name", "uuid",

+ 5 - 1
pycs/database/Label.py

@@ -43,7 +43,11 @@ class Label(NamedBaseModel):
     )
 
     # relationships to other models
-    parent = db.relationship("Label", backref="children", remote_side=[id])
+    parent = db.relationship("Label",
+        backref="children",
+        remote_side=[id],
+    )
+
     serialize_only = (
         "id",
         "name",

+ 24 - 8
pycs/database/Project.py

@@ -37,14 +37,30 @@ class Project(NamedBaseModel):
     __table_args__ = ()
 
     # relationships to other models
-    files = db.relationship("File", backref="project", lazy="dynamic")
-    labels = db.relationship("Label", backref="project", lazy="dynamic")
-    collections = db.relationship("Collection", backref="project", lazy="dynamic")
-
-    serialize_rules = (
-        '-files',
-        '-labels',
-        '-collections',
+    files = db.relationship("File",
+        backref="project",
+        lazy="dynamic",
+        passive_deletes=True)
+
+    labels = db.relationship("Label",
+        backref="project",
+        lazy="dynamic",
+        passive_deletes=True)
+
+    collections = db.relationship("Collection",
+        backref="project",
+        lazy="dynamic",
+        passive_deletes=True)
+
+    serialize_only = (
+        "id",
+        "name",
+        "description",
+        "model_id",
+        "label_provider_id",
+        "root_folder",
+        "external_data",
+        "data_folder",
     )
 
     def label(self, id: int) -> T.Optional[Label]:

+ 1 - 1
pycs/frontend/endpoints/data/RemoveFile.py

@@ -46,5 +46,5 @@ class RemoveFile(View):
             # TODO remove temp files
 
         # send notification
-        self.nm.remove_file(file.id)
+        self.nm.remove_file(file.serialize())
         return make_response()

+ 2 - 2
pycs/frontend/endpoints/labels/RemoveLabel.py

@@ -35,7 +35,7 @@ class RemoveLabel(View):
             abort(404)
 
         # find children
-        children = label.children()
+        children = label.children
 
         # start transaction
         with self.db:
@@ -46,7 +46,7 @@ class RemoveLabel(View):
 
             # remove label
             label.remove()
-            self.nm.remove_label(label.id)
+            self.nm.remove_label(label.serialize())
 
         # return success response
         return make_response()

+ 12 - 9
pycs/frontend/endpoints/projects/CreateProject.py

@@ -8,6 +8,7 @@ from flask import make_response, request, abort
 from flask.views import View
 
 from pycs.database.Database import Database
+from pycs.database.Project import Project
 from pycs.frontend.endpoints.projects.ExecuteExternalStorage import ExecuteExternalStorage
 from pycs.frontend.endpoints.projects.ExecuteLabelProvider import ExecuteLabelProvider
 from pycs.frontend.notifications.NotificationManager import NotificationManager
@@ -85,42 +86,44 @@ class CreateProject(View):
             model, _ = model.copy_to(f'{model.name} ({name})', model_folder)
 
             # create entry in database
-            created = self.db.create_project(name, description, model, label_provider,
+            project = self.db.create_project(name, description, model, label_provider,
                                              project_folder, external_data, data_folder)
 
         # execute label provider and add labels to project
         if label_provider is not None:
-            ExecuteLabelProvider.execute_label_provider(self.db, self.nm, self.jobs, created,
+            ExecuteLabelProvider.execute_label_provider(self.db, self.nm, self.jobs, project,
                                                         label_provider)
 
         # load model and add collections to the project
         def load_model_and_get_collections():
             with closing(load_pipeline(model.root_folder)) as pipeline:
-                return pipeline.collections.all()
+                return pipeline.collections()
 
+        project_id = project.id
         def add_collections_to_project(provided_collections):
+            project = Project.query.get(project_id)
             with self.db:
                 for position, collection in enumerate(provided_collections):
-                    created.create_collection(collection['reference'],
+                    project.create_collection(collection['reference'],
                                               collection['name'],
                                               collection['description'],
                                               position + 1,
                                               collection['autoselect'])
 
-        self.jobs.run(created,
+        self.jobs.run(project,
                       'Media Collections',
-                      f'{created.name}',
-                      f'{created.id}/media-collections',
+                      f'{project.name}',
+                      f'{project.id}/media-collections',
                       executable=load_model_and_get_collections,
                       result=add_collections_to_project)
 
         # find media files
         if external_data:
-            ExecuteExternalStorage.find_media_files(self.db, self.nm, self.jobs, created)
+            ExecuteExternalStorage.find_media_files(self.db, self.nm, self.jobs, project)
 
         # fire event
         self.nm.create_model(model.id)
-        self.nm.create_project(created.id)
+        self.nm.create_project(project.id)
 
         # return success response
         return make_response()

+ 3 - 2
pycs/frontend/endpoints/projects/ExecuteLabelProvider.py

@@ -72,12 +72,13 @@ class ExecuteLabelProvider(View):
                 provided_labels = label_provider_impl.get_labels()
                 return provided_labels
 
+        project_id = project.id
         # result adds the received labels to the database and fires events
         def result(provided_labels):
+            project = Project.query.get(project_id)
             with db:
                 for label in provided_labels:
-                    created_label, insert = project.create_label(label['name'], label['id'],
-                                                                 label['parent'])
+                    created_label, insert = project.create_label(**label)
 
                     if insert:
                         nm.create_label(created_label.id)

+ 12 - 6
pycs/frontend/endpoints/projects/RemoveProject.py

@@ -24,7 +24,7 @@ class RemoveProject(View):
         # extract request data
         data = request.get_json(force=True)
 
-        if 'remove' not in data or data['remove'] is not True:
+        if not data.get('remove', False):
             abort(400)
 
         # start transaction
@@ -38,14 +38,20 @@ class RemoveProject(View):
             model = project.model
             model.remove(commit=False)
 
-            # remove from database
+            # ON_DELETE=cascade does not seem to work...
+            # hence, this 3 delete lines here
+            project.files.delete()
+            project.labels.delete()
+            project.collections.delete()
+
+            # remove project from database
             project.remove(commit=False)
 
+            # send update
+            self.nm.remove_model(model.serialize())
+            self.nm.remove_project(project.serialize())
+
             # remove from file system
             shutil.rmtree(project.root_folder)
 
-            # send update
-            self.nm.remove_model(model.id)
-            self.nm.remove_project(project.id)
-
             return make_response()

+ 1 - 1
pycs/frontend/endpoints/results/CreateResult.py

@@ -53,7 +53,7 @@ class CreateResult(View):
             for result in file.results.all():
                 if result.type == 'labeled-image':
                     result.remove()
-                    self.nm.remove_result(result.id)
+                    self.nm.remove_result(result.serialize())
 
             # insert into database
             result = file.create_result('user', rtype, label, data)

+ 1 - 1
pycs/frontend/endpoints/results/RemoveResult.py

@@ -33,5 +33,5 @@ class RemoveResult(View):
         with self.db:
             result.remove()
 
-        self.nm.remove_result(result.id)
+        self.nm.remove_result(result.serialize())
         return make_response()

+ 1 - 1
pycs/frontend/endpoints/results/ResetResults.py

@@ -38,6 +38,6 @@ class ResetResults(View):
                 result.remove()
 
         for result in results:
-            self.nm.remove_result(result.id)
+            self.nm.remove_result(result.serialize())
 
         return make_response()

+ 12 - 11
pycs/frontend/notifications/NotificationManager.py

@@ -27,7 +27,8 @@ class NotificationManager:
         else:
             obj = obj_id
 
-        app.logger.debug(name, obj)
+        assert obj is not None, "Object was unexpectedly None!"
+        app.logger.debug(f"{name}: {obj}")
 
         enc = self.json.default(obj)
         self.sio.emit(name, enc)
@@ -68,14 +69,14 @@ class NotificationManager:
         """
         self.__emit('create-model', created_model_id, Model)
 
-    def remove_model(self, removed_model_id: int):
+    def remove_model(self, model_serialized: dict):
         """
         fire remove-model event
 
         :param removed_model:
         :return:
         """
-        self.__emit('remove-model', removed_model_id, Model)
+        self.__emit('remove-model', model_serialized)
 
     def create_project(self, created_project_id: int):
         """
@@ -86,14 +87,14 @@ class NotificationManager:
         """
         self.__emit('create-project', created_project_id, Project)
 
-    def remove_project(self, removed_project_id: int):
+    def remove_project(self, project_serialized: dict):
         """
         fire remove-project event
 
         :param removed_project:
         :return:
         """
-        self.__emit('remove-project', removed_project_id, Project)
+        self.__emit('remove-project', project_serialized)
 
     def edit_project(self, edited_project_id: int):
         """
@@ -122,14 +123,14 @@ class NotificationManager:
         """
         self.__emit('edit-label', edited_label_id, Label)
 
-    def remove_label(self, removed_label_id: int):
+    def remove_label(self, label_serialized: dict):
         """
         fire remove-label event
 
         :param removed_label:
         :return:
         """
-        self.__emit('remove-label', removed_label_id, Label)
+        self.__emit('remove-label', label_serialized)
 
     def create_file(self, created_file_id: int):
         """
@@ -149,14 +150,14 @@ class NotificationManager:
         """
         self.__emit('edit-file', edited_file_id, File)
 
-    def remove_file(self, removed_file_id: int):
+    def remove_file(self, file_serialized: dict):
         """
         fire remove-file event
 
         :param removed_file:
         :return:
         """
-        self.__emit('remove-file', removed_file_id, File)
+        self.__emit('remove-file', file_serialized)
 
     def create_result(self, created_result_id: int):
         """
@@ -176,11 +177,11 @@ class NotificationManager:
         """
         self.__emit('edit-result', edited_result_id, Result)
 
-    def remove_result(self, removed_result_id: int):
+    def remove_result(self, result_serialized: dict):
         """
         fire remove-result event
 
         :param removed_result:
         :return:
         """
-        self.__emit('remove-result', removed_result_id, Result)
+        self.__emit('remove-result', result_serialized)

+ 4 - 5
pycs/frontend/util/JSONEncoder.py

@@ -25,9 +25,8 @@ class JSONEncoder(Base):
         if isinstance(o, datetime.datetime):
             return str(o)
 
-        try:
-            return o.__dict__
-        except:
-            import pdb; pdb.set_trace()
-            raise
+        if isinstance(o, dict):
+            return o
+
+        return o.__dict__
 

+ 5 - 5
pycs/interfaces/LabelProvider.py

@@ -32,17 +32,17 @@ class LabelProvider:
         raise NotImplementedError
 
     @staticmethod
-    def create_label(identifier, name, parent_identifier=None):
+    def create_label(reference, name, parent_id=None):
         """
         create a label result
 
-        :param identifier: label identifier
+        :param reference: label reference
         :param name: label name
-        :param parent_identifier: parent's identifier
+        :param parent_id: parent's identifier
         :return:
         """
         return {
-            'id': identifier,
             'name': name,
-            'parent': parent_identifier
+            'reference': reference,
+            'parent_id': parent_id
         }

+ 1 - 1
pycs/interfaces/MediaFile.py

@@ -88,7 +88,7 @@ class MediaFile:
         """
         removed = self.__file.remove_results(origin='pipeline')
         for r in removed:
-            self.__notifications.add(self.__notifications.nm.remove_result, r.id)
+            self.__notifications.add(self.__notifications.nm.remove_result, r.serialize())
 
     def __get_results(self, origin: str) -> List[Union[MediaImageLabel, MediaBoundingBox]]:
         def map_r(result: Result) -> Union[MediaImageLabel, MediaBoundingBox]:

+ 4 - 2
test/test_database.py

@@ -158,6 +158,7 @@ class TestDatabase(unittest.TestCase):
 
     def test_no_labels_after_project_deletion(self):
 
+        self.assertEqual(0, Label.query.count())
         project = self.database.project(1)
         for i in range(5):
             label, is_new = project.create_label(
@@ -169,10 +170,11 @@ class TestDatabase(unittest.TestCase):
             self.assertIsNotNone(label)
 
         self.assertEqual(5, Label.query.filter_by(project_id=project.id).count())
+        with self.database:
+            project.remove(commit=False)
 
-        project.remove()
         self.assertIsNone(self.database.project(1))
-        self.assertEqual(0, Label.query.filter_by(project_id=project.id).count())
+        self.assertEqual(0, Label.query.count())
 
 if __name__ == '__main__':
     unittest.main()