Browse Source

removed Database class. Fixed tests and other places, where it was used

Dimitri Korsch 3 years ago
parent
commit
f3ce674cbe

+ 0 - 168
pycs/database/Database.py

@@ -1,168 +0,0 @@
-import sqlite3
-from contextlib import closing
-from time import time
-from typing import Optional, Iterator
-
-from pycs import app
-from pycs import db
-from pycs.database.Collection import Collection
-from pycs.database.File import File
-from pycs.database.LabelProvider import LabelProvider
-from pycs.database.Model import Model
-from pycs.database.Project import Project
-from pycs.database.Result import Result
-
-
-class Database:
-    """
-    opens an sqlite database and allows to access several objects
-    """
-
-    def __init__(self, initialization=True, discovery=True):
-        """
-        opens or creates a given sqlite database and creates all required tables
-
-        :param path: path to sqlite database
-        """
-
-        if discovery:
-            # run discovery modules
-            Model.discover("models/")
-            LabelProvider.discover("labels/")
-
-    def __enter__(self):
-        # app.logger.warning("Database.__enter__(): REMOVE ME!")
-        return db
-
-    def __exit__(self, exc_type, exc_val, exc_tb):
-        # app.logger.warning("Database.__exit__(): REMOVE ME!")
-
-        if exc_type is None:
-            db.session.commit()
-        else:
-            app.logger.error("Rolling back a transaction!")
-            db.session.rollback()
-
-    def close(self):
-        app.logger.warning("Database.close(): REMOVE ME!")
-
-    def commit(self):
-        """
-        commit changes
-        """
-        db.session.commit()
-
-    def copy(self):
-        return Database(initialization=False, discovery=False)
-
-    def models(self) -> Iterator[Model]:
-        """
-        get a list of all available models
-
-        :return: iterator of models
-        """
-        return Model.query.all()
-
-    def model(self, identifier: int) -> Optional[Model]:
-        """
-        get a model using its unique identifier
-
-        :param identifier: unique identifier
-        :return: model
-        """
-        return Model.query.get(identifier)
-
-    def label_providers(self) -> Iterator[LabelProvider]:
-        """
-        get a list of all available label providers
-
-        :return: iterator over label providers
-        """
-        return LabelProvider.query.all()
-
-    def label_provider(self, identifier: int) -> Optional[LabelProvider]:
-        """
-        get a label provider using its unique identifier
-
-        :param identifier: unique identifier
-        :return: label provider
-        """
-        return LabelProvider.query.get(identifier)
-
-    def projects(self) -> Iterator[Project]:
-        """
-        get a list of all available projects
-
-        :return: iterator over projects
-        """
-        return Project.query.all()
-
-    def project(self, identifier: int) -> Optional[Project]:
-        """
-        get a project using its unique identifier
-
-        :param identifier: unique identifier
-        :return: project
-        """
-        return Project.query.get(identifier)
-
-    def collection(self, identifier: int) -> Optional[Collection]:
-        """
-        get a collection using its unique identifier
-
-        :param identifier: unique identifier
-        :return: collection
-        """
-        return Collection.query.get(identifier)
-
-    def file(self, identifier) -> Optional[File]:
-        """
-        get a file using its unique identifier
-
-        :param identifier: unique identifier
-        :return: file
-        """
-        return File.query.get(identifier)
-
-    def result(self, identifier) -> Optional[Result]:
-        """
-        get a result using its unique identifier
-
-        :param identifier: unique identifier
-        :return: result
-        """
-        return Result.query.get(identifier)
-
-    def create_project(self,
-                       name: str,
-                       description: str,
-                       model: Model,
-                       label_provider: Optional[LabelProvider],
-                       root_folder: str,
-                       external_data: bool,
-                       data_folder: str,
-                       commit: bool = True):
-        """
-        insert a project into the database
-
-        :param name: project name
-        :param description: project description
-        :param model: used model
-        :param label_provider: used label provider (optional)
-        :param root_folder: path to project folder
-        :param external_data: whether an external data directory is used
-        :param data_folder: path to data folder
-        :return: created project
-        """
-        # prepare some values
-
-        return Project.new(
-            commit=commit,
-            name=name,
-            description=description,
-            model=model,
-            label_provider=label_provider,
-            root_folder=root_folder,
-            external_data=external_data,
-            data_folder=data_folder
-        )

+ 3 - 2
pycs/database/base.py

@@ -78,6 +78,7 @@ class NamedBaseModel(BaseModel):
 
     serialize_only = BaseModel.serialize_only + ("name", )
 
-    def set_name(self, name: str):
+    def set_name(self, name: str, commit: bool = True):
         self.name = name
-        self.commit()
+        if commit:
+            self.commit()

+ 3 - 3
pycs/frontend/WebServer.py

@@ -11,7 +11,6 @@ from logging import config
 from pycs import app
 
 from pycs.database.Collection import Collection
-from pycs.database.Database import Database
 from pycs.database.LabelProvider import LabelProvider
 from pycs.database.Label import Label
 from pycs.database.Model import Model
@@ -68,8 +67,9 @@ class WebServer:
         # initialize flask app instance
         self.app = app
 
-        # initialize database
-        self.db = Database()
+        # run discovery modules
+        Model.discover("models/")
+        LabelProvider.discover("labels/")
 
         # start job runner
         self.logger.info('Starting job runner... ')

+ 3 - 4
pycs/frontend/endpoints/labels/EditLabelName.py

@@ -1,7 +1,7 @@
 from flask import request, abort, make_response
 from flask.views import View
 
-from pycs.database.Database import Database
+from pycs.database.Project import Project
 from pycs.frontend.notifications.NotificationManager import NotificationManager
 
 
@@ -12,9 +12,8 @@ class EditLabelName(View):
     # pylint: disable=arguments-differ
     methods = ['POST']
 
-    def __init__(self, db: Database, nm: NotificationManager):
+    def __init__(self, nm: NotificationManager):
         # pylint: disable=invalid-name
-        self.db = db
         self.nm = nm
 
     def dispatch_request(self, project_id: int, label_id: int):
@@ -25,7 +24,7 @@ class EditLabelName(View):
             abort(400)
 
         # find project
-        project = self.db.project(project_id)
+        project = Project.query.get(project_id)
         if project is None:
             abort(404)
 

+ 7 - 2
test/base.py

@@ -7,7 +7,8 @@ from pycs import app
 from pycs import db
 from pycs import settings
 from pycs.frontend.WebServer import WebServer
-from pycs.database.Database import Database
+from pycs.database.Model import Model
+from pycs.database.LabelProvider import LabelProvider
 
 server = None
 
@@ -30,7 +31,11 @@ class BaseTestCase(unittest.TestCase):
         server.start_runner()
 
         # create database
-        self.database = Database(discovery=discovery)
+        if discovery:
+            # run discovery modules
+            Model.discover("models/")
+            LabelProvider.discover("labels/")
+        # self.database = Database(discovery=discovery)
 
     def tearDown(self):
         global server

+ 25 - 25
test/test_database.py

@@ -1,11 +1,12 @@
 import unittest
 
-from pycs.database.Database import Database
+from pycs import db
 from pycs.database.File import File
 from pycs.database.Label import Label
-from pycs.database.Result import Result
-from pycs.database.Model import Model
 from pycs.database.LabelProvider import LabelProvider
+from pycs.database.Model import Model
+from pycs.database.Project import Project
+from pycs.database.Result import Result
 
 from test.base import BaseTestCase
 
@@ -15,7 +16,7 @@ class DatabaseTests(BaseTestCase):
         super().setUp(discovery=False)
 
         # insert default models and label_providers
-        with self.database:
+        with db.session.begin_nested():
 
             for i, supports in enumerate([["labeled-image", "fit"], ["labeled-bounding-boxes"], ["labeled-bounding-boxes"]], 1):
 
@@ -39,11 +40,11 @@ class DatabaseTests(BaseTestCase):
                 )
 
         # projects
-        models = list(self.database.models())
-        label_providers = list(self.database.label_providers())
+        models = Model.query.all()
+        label_providers = LabelProvider.query.all()
 
         for i, model in enumerate(models, 1):
-            self.database.create_project(
+            Project.new(
                 name=f'Project {i}',
                 description=f'Project Description {i}',
                 model=model,
@@ -54,7 +55,7 @@ class DatabaseTests(BaseTestCase):
             )
 
     def test_models(self):
-        models = list(self.database.models())
+        models = Model.query.all()
 
         # test length
         self.assertEqual(len(models), 3)
@@ -78,7 +79,7 @@ class DatabaseTests(BaseTestCase):
         self.assertEqual(copy.supports, ['labeled-image', 'fit'])
 
     def test_label_providers(self):
-        label_providers = list(self.database.label_providers())
+        label_providers = LabelProvider.query.all()
 
         # test length
         self.assertEqual(len(label_providers), 2)
@@ -90,9 +91,9 @@ class DatabaseTests(BaseTestCase):
             self.assertEqual(label_providers[i].root_folder, f'labeldir{i + 1}')
 
     def test_projects(self):
-        models = list(self.database.models())
-        label_providers = list(self.database.label_providers())
-        projects = list(self.database.projects())
+        models = Model.query.all()
+        label_providers = LabelProvider.query.all()
+        projects = Project.query.all()
 
         # create projects
         for i in range(3):
@@ -113,28 +114,27 @@ class DatabaseTests(BaseTestCase):
             self.assertEqual(project.data_folder, f'datadir{i + 1}')
 
         # get projects
-        self.assertEqual(len(list(self.database.projects())), 3)
+        self.assertEqual(Project.query.count(), 3)
 
         # remove a project
-        list(self.database.projects())[0].remove()
-        projects = list(self.database.projects())
+        Project.query.first().remove()
 
-        self.assertEqual(len(projects), 2)
-        self.assertEqual(projects[0].name, 'Project 2')
+        self.assertEqual(Project.query.count(), 2)
+        self.assertEqual(Project.query.first().name, 'Project 2')
 
         # set properties
-        project = list(self.database.projects())[0]
+        project = Project.query.first()
 
         project.set_name('Project 0')
-        self.assertEqual(list(self.database.projects())[0].name, 'Project 0')
+        self.assertEqual(Project.query.first().name, 'Project 0')
 
         project.set_description('Description 0')
-        self.assertEqual(list(self.database.projects())[0].description, 'Description 0')
+        self.assertEqual(Project.query.first().description, 'Description 0')
 
 
     def test_no_files_after_project_deletion(self):
 
-        project = self.database.project(1)
+        project = Project.query.get(1)
         for i in range(5):
             file, is_new = project.add_file(
                 uuid=f"some_string{i}",
@@ -151,13 +151,13 @@ class DatabaseTests(BaseTestCase):
         self.assertEqual(5, File.query.filter_by(project_id=project.id).count())
 
         project.remove()
-        self.assertIsNone(self.database.project(1))
+        self.assertIsNone(Project.query.get(1))
         self.assertEqual(0, File.query.filter_by(project_id=project.id).count())
 
     def test_no_labels_after_project_deletion(self):
 
         self.assertEqual(0, Label.query.count())
-        project = self.database.project(1)
+        project = Project.query.get(1)
         for i in range(5):
             label, is_new = project.create_label(
                 name=f"label{i}",
@@ -171,13 +171,13 @@ class DatabaseTests(BaseTestCase):
 
         project.remove()
 
-        self.assertIsNone(self.database.project(1))
+        self.assertIsNone(Project.query.get(1))
         self.assertEqual(0, Label.query.count())
 
 
     def test_no_results_after_file_deletion(self):
 
-        project = self.database.project(1)
+        project = Project.query.get(1)
         self.assertIsNotNone(project)
 
         file, is_new = project.add_file(