6
0
Переглянути джерело

removed Database class usage in some of the views

Dimitri Korsch 3 роки тому
батько
коміт
b1e575d79c

+ 30 - 31
pycs/frontend/WebServer.py

@@ -115,13 +115,13 @@ class WebServer:
 
     def init_notifications(self):
         # create notification manager
-        self.notifications = n = NotificationManager(self.sio)
+        self.nm = NotificationManager(self.sio)
 
-        self.jobs.on_create(n.create_job)
-        self.jobs.on_start(n.edit_job)
-        self.jobs.on_progress(n.edit_job)
-        self.jobs.on_finish(n.edit_job)
-        self.jobs.on_remove(n.remove_job)
+        self.jobs.on_create(self.nm.create_job)
+        self.jobs.on_start(self.nm.edit_job)
+        self.jobs.on_progress(self.nm.edit_job)
+        self.jobs.on_finish(self.nm.edit_job)
+        self.jobs.on_remove(self.nm.remove_job)
 
     def development_init(self):
 
@@ -207,19 +207,19 @@ class WebServer:
         )
         self.app.add_url_rule(
             '/projects/<int:identifier>/labels',
-            view_func=CreateLabel.as_view('create_label', self.db, self.notifications)
+            view_func=CreateLabel.as_view('create_label', self.db, self.nm)
         )
         self.app.add_url_rule(
             '/projects/<int:project_id>/labels/<int:label_id>/remove',
-            view_func=RemoveLabel.as_view('remove_label', self.db, self.notifications)
+            view_func=RemoveLabel.as_view('remove_label', self.db, self.nm)
         )
         self.app.add_url_rule(
             '/projects/<int:project_id>/labels/<int:label_id>/name',
-            view_func=EditLabelName.as_view('edit_label_name', self.db, self.notifications)
+            view_func=EditLabelName.as_view('edit_label_name', self.db, self.nm)
         )
         self.app.add_url_rule(
             '/projects/<int:project_id>/labels/<int:label_id>/parent',
-            view_func=EditLabelParent.as_view('edit_label_parent', self.db, self.notifications)
+            view_func=EditLabelParent.as_view('edit_label_parent', self.db, self.nm)
         )
 
         # collections
@@ -235,7 +235,7 @@ class WebServer:
         # data
         self.app.add_url_rule(
             '/projects/<int:identifier>/data',
-            view_func=UploadFile.as_view('upload_file', self.db, self.notifications)
+            view_func=UploadFile.as_view('upload_file', self.db, self.nm)
         )
         self.app.add_url_rule(
             '/projects/<int:project_id>/data/<int:start>/<int:length>',
@@ -243,7 +243,7 @@ class WebServer:
         )
         self.app.add_url_rule(
             '/data/<int:identifier>/remove',
-            view_func=RemoveFile.as_view('remove_file', self.db, self.notifications)
+            view_func=RemoveFile.as_view('remove_file', self.db, self.nm)
         )
         self.app.add_url_rule(
             '/data/<int:file_id>',
@@ -269,75 +269,74 @@ class WebServer:
         )
         self.app.add_url_rule(
             '/data/<int:file_id>/results',
-            view_func=CreateResult.as_view('create_result', self.db, self.notifications)
+            view_func=CreateResult.as_view('create_result', self.db, self.nm)
         )
         self.app.add_url_rule(
             '/data/<int:file_id>/reset',
-            view_func=ResetResults.as_view('reset_results', self.db, self.notifications)
+            view_func=ResetResults.as_view('reset_results', self.db, self.nm)
         )
         self.app.add_url_rule(
             '/results/<int:result_id>/remove',
-            view_func=RemoveResult.as_view('remove_result', self.db, self.notifications)
+            view_func=RemoveResult.as_view('remove_result', self.db, self.nm)
         )
         self.app.add_url_rule(
             '/results/<int:result_id>/confirm',
-            view_func=ConfirmResult.as_view('confirm_result', self.db, self.notifications)
+            view_func=ConfirmResult.as_view('confirm_result', self.db, self.nm)
         )
         self.app.add_url_rule(
             '/results/<int:result_id>/label',
-            view_func=EditResultLabel.as_view('edit_result_label', self.db, self.notifications)
+            view_func=EditResultLabel.as_view('edit_result_label', self.db, self.nm)
         )
         self.app.add_url_rule(
             '/results/<int:result_id>/data',
-            view_func=EditResultData.as_view('edit_result_data', self.db, self.notifications)
+            view_func=EditResultData.as_view('edit_result_data', self.db, self.nm)
         )
 
         # projects
         self.app.add_url_rule(
             '/projects',
-            view_func=ListProjects.as_view('list_projects', self.db)
+            view_func=ListProjects.as_view('list_projects')
         )
         self.app.add_url_rule(
             '/projects',
-            view_func=CreateProject.as_view('create_project', self.db, self.notifications, self.jobs)
+            view_func=CreateProject.as_view('create_project', self.nm, self.jobs)
         )
         self.app.add_url_rule(
             '/projects/<int:identifier>/label_provider',
-            view_func=ExecuteLabelProvider.as_view('execute_label_provider', self.db,
-                                                   self.notifications, self.jobs)
+            view_func=ExecuteLabelProvider.as_view('execute_label_provider',
+                                                   self.nm, self.jobs)
         )
         self.app.add_url_rule(
             '/projects/<int:identifier>/external_storage',
-            view_func=ExecuteExternalStorage.as_view('execute_external_storage', self.db,
-                                                     self.notifications, self.jobs)
+            view_func=ExecuteExternalStorage.as_view('execute_external_storage',
+                                                     self.nm, self.jobs)
         )
         self.app.add_url_rule(
             '/projects/<int:identifier>/remove',
-            view_func=RemoveProject.as_view('remove_project', self.db, self.notifications)
+            view_func=RemoveProject.as_view('remove_project', self.nm)
         )
         self.app.add_url_rule(
             '/projects/<int:identifier>/name',
-            view_func=EditProjectName.as_view('edit_project_name', self.db, self.notifications)
+            view_func=EditProjectName.as_view('edit_project_name', self.nm)
         )
         self.app.add_url_rule(
             '/projects/<int:identifier>/description',
-            view_func=EditProjectDescription.as_view('edit_project_description', self.db,
-                                                     self.notifications)
+            view_func=EditProjectDescription.as_view('edit_project_description', self.nm)
         )
 
         # pipelines
         self.app.add_url_rule(
             '/projects/<int:project_id>/pipelines/fit',
-            view_func=FitModel.as_view('fit_model', self.db, self.jobs, self.pipelines)
+            view_func=FitModel.as_view('fit_model', self.jobs, self.pipelines)
         )
         self.app.add_url_rule(
             '/projects/<int:project_id>/pipelines/predict',
-            view_func=PredictModel.as_view('predict_model', self.db, self.notifications,
+            view_func=PredictModel.as_view('predict_model', self.nm,
                                            self.jobs, self.pipelines)
         )
         self.app.add_url_rule(
             '/data/<int:file_id>/predict',
-            view_func=PredictFile.as_view('predict_file', self.db, self.notifications,
+            view_func=PredictFile.as_view('predict_file', self.nm,
                                           self.jobs, self.pipelines)
         )
 

+ 2 - 7
pycs/frontend/endpoints/ListProjects.py

@@ -1,7 +1,7 @@
 from flask import jsonify
 from flask.views import View
 
-from pycs.database.Database import Database
+from pycs.database.Project import Project
 
 
 class ListProjects(View):
@@ -11,10 +11,5 @@ class ListProjects(View):
     # pylint: disable=arguments-differ
     methods = ['GET']
 
-    def __init__(self, db: Database):
-        # pylint: disable=invalid-name
-        self.db = db
-
     def dispatch_request(self):
-        projects = list(self.db.projects())
-        return jsonify(projects)
+        return jsonify(Project.query.all())

+ 4 - 5
pycs/frontend/endpoints/pipelines/FitModel.py

@@ -16,9 +16,8 @@ class FitModel(View):
     # pylint: disable=arguments-differ
     methods = ['POST']
 
-    def __init__(self, db: Database, jobs: JobRunner, pipelines: PipelineCache):
+    def __init__(self, jobs: JobRunner, pipelines: PipelineCache):
         # pylint: disable=invalid-name
-        self.db = db
         self.jobs = jobs
         self.pipelines = pipelines
 
@@ -26,7 +25,7 @@ class FitModel(View):
         # extract request data
         data = request.get_json(force=True)
 
-        if 'fit' not in data or data['fit'] is not True:
+        if not data.get('fit', False):
             return abort(400)
 
         # find project
@@ -40,14 +39,14 @@ class FitModel(View):
                           'Model Interaction',
                           f'{project.name} (fit model with new data)',
                           f'{project.name}/model-interaction',
-                          self.load_and_fit, self.db, project.id)
+                          self.load_and_fit, self.pipelines, project.id)
         except JobGroupBusyException:
             return abort(400)
 
         return make_response()
 
     @staticmethod
-    def load_and_fit(database: Database, pipelines: PipelineCache, project_id: int):
+    def load_and_fit(pipelines: PipelineCache, project_id: int):
         pipeline = None
 
         project = Project.query.get(project_id)

+ 7 - 8
pycs/frontend/endpoints/pipelines/PredictFile.py

@@ -2,7 +2,8 @@ from flask import make_response, request, abort
 from flask.views import View
 
 from pycs.database.Database import Database
-from pycs.frontend.endpoints.pipelines.PredictModel import PredictModel as Predict
+from pycs.database.File import File
+from pycs.frontend.endpoints.pipelines.PredictModel import PredictModel
 from pycs.frontend.notifications.NotificationList import NotificationList
 from pycs.frontend.notifications.NotificationManager import NotificationManager
 from pycs.jobs.JobGroupBusyException import JobGroupBusyException
@@ -17,10 +18,8 @@ class PredictFile(View):
     # pylint: disable=arguments-differ
     methods = ['POST']
 
-    def __init__(self,
-                 db: Database, nm: NotificationManager, jobs: JobRunner, pipelines: PipelineCache):
+    def __init__(self, nm: NotificationManager, jobs: JobRunner, pipelines: PipelineCache):
         # pylint: disable=invalid-name
-        self.db = db
         self.nm = nm
         self.jobs = jobs
         self.pipelines = pipelines
@@ -33,7 +32,7 @@ class PredictFile(View):
             return abort(400)
 
         # find file
-        file = self.db.file(file_id)
+        file = File.query.get(file_id)
         if file is None:
             return abort(404)
 
@@ -48,9 +47,9 @@ class PredictFile(View):
                           'Model Interaction',
                           f'{project.name} (create predictions)',
                           f'{project.name}/model-interaction',
-                          Predict.load_and_predict,
-                          self.db, self.pipelines, notifications, project.id, [file],
-                          progress=Predict.progress)
+                          PredictModel.load_and_predict,
+                          self.pipelines, notifications, project.id, [file],
+                          progress=PredictModel.progress)
         except JobGroupBusyException:
             return abort(400)
 

+ 9 - 13
pycs/frontend/endpoints/pipelines/PredictModel.py

@@ -1,9 +1,12 @@
 from typing import Any
 
-from flask import make_response, request, abort
+from flask import abort
+from flask import make_response
+from flask import request
 from flask.views import View
 
 from pycs import app
+from pycs import db
 from pycs.database.Database import Database
 from pycs.database.Project import Project
 from pycs.frontend.notifications.NotificationList import NotificationList
@@ -22,10 +25,8 @@ class PredictModel(View):
     # pylint: disable=arguments-differ
     methods = ['POST']
 
-    def __init__(self,
-                 db: Database, nm: NotificationManager, jobs: JobRunner, pipelines: PipelineCache):
+    def __init__(self, nm: NotificationManager, jobs: JobRunner, pipelines: PipelineCache):
         # pylint: disable=invalid-name
-        self.db = db
         self.nm = nm
         self.jobs = jobs
         self.pipelines = pipelines
@@ -50,8 +51,8 @@ class PredictModel(View):
                           'Model Interaction',
                           f'{project.name} (create predictions)',
                           f'{project.name}/model-interaction',
-                          self.load_and_predict,
-                          self.db, self.pipelines, notifications,
+                          PredictModel.load_and_predict,
+                          self.pipelines, notifications,
                           project.id, data['predict'],
                           progress=self.progress)
         except JobGroupBusyException:
@@ -60,14 +61,12 @@ class PredictModel(View):
         return make_response()
 
     @staticmethod
-    def load_and_predict(database: Database, pipelines: PipelineCache,
+    def load_and_predict(pipelines: PipelineCache,
                          notifications: NotificationList, project_id: int, file_filter: Any):
-        db = None
         pipeline = None
 
         # create new database instance
         try:
-            db = database.copy()
             project = Project.query.get(project_id)
             model = project.model
             storage = MediaStorage(project_id, notifications)
@@ -101,7 +100,7 @@ class PredictModel(View):
                     pipeline.execute(storage, file)
 
                     # commit changes and yield progress
-                    db.commit()
+                    db.session.commit()
                     yield index / length, notifications
 
                     index += 1
@@ -119,9 +118,6 @@ class PredictModel(View):
             traceback.print_exc()
             app.logger.warning(f"Pipeline Error #1: {e}")
 
-        finally:
-            if db is not None:
-                db.close()
 
     @staticmethod
     def progress(progress: float, notifications: NotificationList):

+ 19 - 15
pycs/frontend/endpoints/projects/CreateProject.py

@@ -1,15 +1,19 @@
 from contextlib import closing
 from os import mkdir
 from os import path
+from pathlib import Path
 from shutil import copytree
 from uuid import uuid1
-from pathlib import Path
 
-from flask import make_response, request, abort
+from flask import abort
+from flask import make_response
+from flask import request
 from flask.views import View
 
 from pycs import app
-from pycs.database.Database import Database
+from pycs import db
+from pycs.database.LabelProvider import LabelProvider
+from pycs.database.Model import Model
 from pycs.database.Project import Project
 from pycs.frontend.endpoints.projects.ExecuteExternalStorage import ExecuteExternalStorage
 from pycs.frontend.endpoints.projects.ExecuteLabelProvider import ExecuteLabelProvider
@@ -25,9 +29,8 @@ class CreateProject(View):
     # pylint: disable=arguments-differ
     methods = ['POST']
 
-    def __init__(self, db: Database, nm: NotificationManager, jobs: JobRunner):
+    def __init__(self, nm: NotificationManager, jobs: JobRunner):
         # pylint: disable=invalid-name
-        self.db = db
         self.nm = nm
         self.jobs = jobs
 
@@ -50,9 +53,9 @@ class CreateProject(View):
         external_data = data_folder is not None
 
         # start transaction
-        with self.db:
+        with db.session.start():
             # find model
-            model = self.db.model(int(model_id))
+            model = Model.query.get(model_id)
 
             if model is None:
                 return abort(404, "Model not found")
@@ -61,7 +64,7 @@ class CreateProject(View):
             if label_provider_id is None:
                 label_provider = None
             else:
-                label_provider = self.db.label_provider(int(label_provider_id))
+                label_provider = LabelProvider.query.get(label_provider_id)
 
                 if label_provider is None:
                     return abort(404, "Label provider not found")
@@ -91,24 +94,25 @@ class CreateProject(View):
             model, _ = model.copy_to(f'{model.name} ({name})', str(model_folder))
 
             # create entry in database
-            project = self.db.create_project(name, description, model, label_provider,
-                                             str(project_folder), external_data,
-                                             str(data_folder))
+            project = Project.new(name, description, model, label_provider,
+                                  str(project_folder), external_data,
+                                  str(data_folder))
 
         # execute label provider and add labels to project
         if label_provider is not None:
-            ExecuteLabelProvider.execute_label_provider(self.db, self.nm, self.jobs, project,
+            ExecuteLabelProvider.execute_label_provider(self.nm, self.jobs, project,
                                                         label_provider)
 
+        root_folder = model.root_folder
         # load model and add collections to the project
         def load_model_and_get_collections():
-            with closing(load_pipeline(model.root_folder)) as pipeline:
+            with closing(load_pipeline(root_folder)) as pipeline:
                 return pipeline.collections()
 
         project_id = project.id
         def add_collections_to_project(provided_collections):
             project = Project.query.get(project_id)
-            with self.db:
+            with db.session.start():
                 for position, collection in enumerate(provided_collections):
                     project.create_collection(collection['reference'],
                                               collection['name'],
@@ -125,7 +129,7 @@ class CreateProject(View):
 
         # find media files
         if external_data:
-            ExecuteExternalStorage.find_media_files(self.db, self.nm, self.jobs, project)
+            ExecuteExternalStorage.find_media_files(self.nm, self.jobs, project)
 
         # fire event
         self.nm.create_model(model.id)

+ 5 - 6
pycs/frontend/endpoints/projects/EditProjectDescription.py

@@ -1,7 +1,8 @@
 from flask import make_response, abort
 from flask.views import View, request
 
-from pycs.database.Database import Database
+from pycs import db
+from pycs.database.Project import Project
 from pycs.frontend.notifications.NotificationManager import NotificationManager
 
 
@@ -12,9 +13,8 @@ class EditProjectDescription(View):
     # pylint: disable=arguments-differ
     methods = ['POST']
 
-    def __init__(self, db: Database, nm: NotificationManager):
+    def __init__(self, nm: NotificationManager):
         # pylint: disable=invalid-name
-        self.db = db
         self.nm = nm
 
     def dispatch_request(self, identifier):
@@ -24,10 +24,9 @@ class EditProjectDescription(View):
         if 'description' not in data or not data['description']:
             return abort(400)
 
-        # start transaction
-        with self.db:
+        with db.session.start()
             # find project
-            project = self.db.project(identifier)
+            project = Project.query.get(identifier)
             if project is None:
                 return abort(404)
 

+ 6 - 6
pycs/frontend/endpoints/projects/EditProjectName.py

@@ -1,7 +1,8 @@
 from flask import make_response, abort
 from flask.views import View, request
 
-from pycs.database.Database import Database
+from pycs import db
+from pycs.database.Project import Project
 from pycs.frontend.notifications.NotificationManager import NotificationManager
 
 
@@ -12,22 +13,21 @@ class EditProjectName(View):
     # pylint: disable=arguments-differ
     methods = ['POST']
 
-    def __init__(self, db: Database, nm: NotificationManager):
+    def __init__(self, nm: NotificationManager):
         # pylint: disable=invalid-name
-        self.db = db
         self.nm = nm
 
     def dispatch_request(self, identifier):
         # extract request data
         data = request.get_json(force=True)
 
-        if 'name' not in data or not data['name']:
+        if data.get('name') is None:
             return abort(400)
 
         # start transaction
-        with self.db:
+        with db.session.start():
             # find project
-            project = self.db.project(identifier)
+            project = Project.query.get(identifier)
             if project is None:
                 return abort(404)
 

+ 27 - 23
pycs/frontend/endpoints/projects/ExecuteExternalStorage.py

@@ -1,12 +1,13 @@
-from os import listdir
-from os import path
-from os.path import isfile
+import os
+
 from uuid import uuid1
 
-from flask import make_response, request, abort
+from flask import abort
+from flask import make_response
+from flask import request
 from flask.views import View
 
-from pycs.database.Database import Database
+from pycs import db
 from pycs.database.Project import Project
 from pycs.frontend.notifications.NotificationManager import NotificationManager
 from pycs.jobs.JobGroupBusyException import JobGroupBusyException
@@ -21,9 +22,8 @@ class ExecuteExternalStorage(View):
     # pylint: disable=arguments-differ
     methods = ['POST']
 
-    def __init__(self, db: Database, nm: NotificationManager, jobs: JobRunner):
+    def __init__(self, nm: NotificationManager, jobs: JobRunner):
         # pylint: disable=invalid-name
-        self.db = db
         self.nm = nm
         self.jobs = jobs
 
@@ -31,11 +31,11 @@ class ExecuteExternalStorage(View):
         # extract request data
         data = request.get_json(force=True)
 
-        if 'execute' not in data or data['execute'] is not True:
+        if not data.get('execute', False):
             return abort(400)
 
         # find project
-        project = self.db.project(identifier)
+        project = Project.query.get(identifier)
         if project is None:
             return abort(404)
 
@@ -44,44 +44,47 @@ class ExecuteExternalStorage(View):
 
         # execute label provider and add labels to project
         try:
-            self.find_media_files(self.db, self.nm, self.jobs, project)
+            self.find_media_files(self.nm, self.jobs, project)
+
         except JobGroupBusyException:
             return abort(400)
 
         return make_response()
 
     @staticmethod
-    def find_media_files(db: Database, nm: NotificationManager, jobs: JobRunner, project: Project):
+    def find_media_files(nm: NotificationManager, jobs: JobRunner, project: Project):
         """
         start a job that finds media files in the projects data_folder and adds them to the
         database afterwards
 
-        :param db: database object
         :param nm: notification manager object
         :param jobs: job runner object
         :param project: project
         :return:
         """
 
+        data_folder = project.data_folder
+        project_id = project.id
+
         # pylint: disable=invalid-name
         # find lists the given data folder and prepares item dictionaries
         def find():
-            files = listdir(project.data_folder)
+            files = os.listdir(data_folder)
             length = len(files)
 
             elements = []
             current = 0
 
             for file_name in files:
-                file_path = path.join(project.data_folder, file_name)
-                if not isfile(file_path):
+                file_path = os.path.join(data_folder, file_name)
+                if not os.path.isfile(file_path):
                     continue
 
-                file_name, file_extension = path.splitext(file_name)
-                file_size = path.getsize(file_path)
+                file_name, file_extension = os.path.splitext(file_name)
+                file_size = os.path.getsize(file_path)
 
                 try:
-                    ftype, frames, fps = file_info(project.data_folder, file_name, file_extension)
+                    ftype, frames, fps = file_info(data_folder, file_name, file_extension)
                 except ValueError:
                     continue
 
@@ -97,13 +100,14 @@ class ExecuteExternalStorage(View):
 
         # progress inserts elements into the database and fires events
         def progress(elements, current, length):
-            with db:
+            with db.session.start():
+                project = Project.query.get(project_id)
                 for ftype, file_name, file_extension, file_size, frames, fps in elements:
-                    uuid = str(uuid1())
-                    file, insert = project.add_file(uuid, ftype, file_name, file_extension,
-                                                    file_size, file_name, frames, fps)
+                    file, is_new = project.add_file(str(uuid1()), ftype, file_name,
+                                                    file_extension, file_size, file_name,
+                                                    frames, fps)
 
-                    if insert:
+                    if is_new:
                         nm.create_file(file.id)
 
             return current / length

+ 11 - 11
pycs/frontend/endpoints/projects/ExecuteLabelProvider.py

@@ -1,9 +1,11 @@
 from contextlib import closing
 
-from flask import make_response, request, abort
+from flask import abort
+from flask import make_response
+from flask import request
 from flask.views import View
 
-from pycs.database.Database import Database
+from pycs import db
 from pycs.database.LabelProvider import LabelProvider
 from pycs.database.Project import Project
 from pycs.frontend.notifications.NotificationManager import NotificationManager
@@ -19,9 +21,8 @@ class ExecuteLabelProvider(View):
     # pylint: disable=arguments-differ
     methods = ['POST']
 
-    def __init__(self, db: Database, nm: NotificationManager, jobs: JobRunner):
+    def __init__(self, nm: NotificationManager, jobs: JobRunner):
         # pylint: disable=invalid-name
-        self.db = db
         self.nm = nm
         self.jobs = jobs
 
@@ -33,7 +34,7 @@ class ExecuteLabelProvider(View):
             return abort(400)
 
         # find project
-        project = self.db.project(identifier)
+        project = Project.query.get(identifier)
         if project is None:
             return abort(404)
 
@@ -44,20 +45,19 @@ class ExecuteLabelProvider(View):
 
         # execute label provider and add labels to project
         try:
-            self.execute_label_provider(self.db, self.nm, self.jobs, project, label_provider)
+            self.execute_label_provider(self.nm, self.jobs, project, label_provider)
         except JobGroupBusyException:
             return abort(400)
 
         return make_response()
 
     @staticmethod
-    def execute_label_provider(db: Database, nm: NotificationManager, jobs: JobRunner,
+    def execute_label_provider(nm: NotificationManager, jobs: JobRunner,
                                project: Project, label_provider: LabelProvider):
         """
         start a job that loads and executes a label provider and saves its results to the
         database afterwards
 
-        :param db: database object
         :param nm: notification manager object
         :param jobs: job runner object
         :param project: project
@@ -76,11 +76,11 @@ class ExecuteLabelProvider(View):
         # result adds the received labels to the database and fires events
         def result(provided_labels):
             project = Project.query.get(project_id)
-            with db:
+            with db.session.start():
                 for label in provided_labels:
-                    created_label, insert = project.create_label(**label)
+                    created_label, is_new = project.create_label(**label)
 
-                    if insert:
+                    if is_new:
                         nm.create_label(created_label.id)
                     else:
                         nm.edit_label(created_label.id)

+ 10 - 11
pycs/frontend/endpoints/projects/RemoveProject.py

@@ -4,7 +4,7 @@ from flask import make_response, request, abort
 from flask.views import View
 
 from pycs import db
-from pycs.database.Database import Database
+from pycs.database.Project import Project
 from pycs.frontend.notifications.NotificationManager import NotificationManager
 
 
@@ -15,9 +15,8 @@ class RemoveProject(View):
     # pylint: disable=arguments-differ
     methods = ['POST']
 
-    def __init__(self, db: Database, nm: NotificationManager):
+    def __init__(self, nm: NotificationManager):
         # pylint: disable=invalid-name
-        self.db = db
         self.nm = nm
 
     def dispatch_request(self, identifier):
@@ -28,9 +27,9 @@ class RemoveProject(View):
             abort(400)
 
         # start transaction
-        with self.db:
+        with db.session.start():
             # find project
-            project = self.db.project(identifier)
+            project = Project.query.id(identifier)
             if project is None:
                 abort(404, "Project not found")
 
@@ -41,11 +40,11 @@ class RemoveProject(View):
             # remove project from database
             project.remove(commit=False)
 
-            # send update
-            self.nm.remove_model(model.serialize())
-            self.nm.remove_project(project.serialize())
+        # send update
+        self.nm.remove_model(model.serialize())
+        self.nm.remove_project(project.serialize())
 
-            # remove from file system
-            shutil.rmtree(project.root_folder)
+        # remove from file system
+        shutil.rmtree(project.root_folder)
 
-            return make_response()
+        return make_response()