6
0
Eric Tröbs 4 жил өмнө
parent
commit
1df7964c29

+ 4 - 3
.gitlab-ci.yml

@@ -37,7 +37,7 @@ webui:
     - pip install eventlet flask python-socketio
     - pip install coverage pylint
   script:
-    # - coverage run --source=pycs/ -m unittest discover test/
+    - coverage run --source=pycs/ -m unittest discover test/
     - "pylint --fail-under=9.5
          --disable=duplicate-code
          --disable=missing-module-docstring
@@ -71,8 +71,9 @@ tests_3.9:
   stage: test
   image: python:3.9
   <<: *python_test_definition
-  # after_script:
-  #   - coverage report -m
+  after_script:
+    - source env/bin/activate
+    - coverage report -m
 
 tests_3.10:
   stage: test

+ 0 - 0
pycs/__init__.py


+ 8 - 4
pycs/database/Database.py

@@ -17,7 +17,7 @@ class Database:
     opens an sqlite database and allows to access several objects
     """
 
-    def __init__(self, path: str = ':memory:'):
+    def __init__(self, path: str = ':memory:', discovery=True):
         """
         opens or creates a given sqlite database and creates all required tables
 
@@ -114,9 +114,13 @@ class Database:
             ''')
 
         # run discovery modules
-        with self:
-            discover_models(self.con)
-            discover_label_providers(self.con)
+        if discovery:
+            with self:
+                discover_models(self.con)
+                discover_label_providers(self.con)
+
+    def close(self):
+        self.con.close()
 
     def __enter__(self):
         self.con.__enter__()

+ 0 - 0
pycs/database/__init__.py


+ 0 - 0
pycs/database/discovery/__init__.py


+ 0 - 0
pycs/database/util/__init__.py


+ 0 - 0
pycs/frontend/__init__.py


+ 0 - 0
pycs/frontend/endpoints/__init__.py


+ 0 - 0
pycs/frontend/endpoints/data/__init__.py


+ 0 - 0
pycs/frontend/endpoints/jobs/__init__.py


+ 0 - 0
pycs/frontend/endpoints/labels/__init__.py


+ 0 - 0
pycs/frontend/endpoints/pipelines/__init__.py


+ 0 - 0
pycs/frontend/endpoints/projects/__init__.py


+ 0 - 0
pycs/frontend/endpoints/results/__init__.py


+ 0 - 0
pycs/frontend/notifications/__init__.py


+ 0 - 0
pycs/frontend/util/__init__.py


+ 0 - 0
pycs/interfaces/__init__.py


+ 0 - 0
pycs/jobs/__init__.py


+ 0 - 0
pycs/jobs/util/__init__.py


+ 0 - 0
pycs/util/__init__.py


+ 127 - 0
test/test_database.py

@@ -0,0 +1,127 @@
+import unittest
+from contextlib import closing
+
+from pycs.database.Database import Database
+
+
+class TestDatabase(unittest.TestCase):
+    def setUp(self) -> None:
+        # create database
+        self.database = Database(discovery=False)
+
+        # insert default models and label_providers
+        with self.database:
+            with closing(self.database.con.cursor()) as cursor:
+                # models
+                cursor.execute('''
+                    INSERT INTO models (name, description, root_folder, supports)
+                    VALUES 
+                        ('Model 1', 'Description for Model 1', 'modeldir1', '["labeled-image", "fit"]'),
+                        ('Model 2', 'Description for Model 2', 'modeldir2', '["labeled-bounding-boxes"]'),
+                        ('Model 3', 'Description for Model 3', 'modeldir3', '["labeled-bounding-boxes"]')
+                ''')
+
+                # label providers
+                cursor.execute('''
+                    INSERT INTO label_providers (name, description, root_folder)
+                    VALUES
+                        ('Label Provider 1', 'Description for Label Provider 1', 'labeldir1'),
+                        ('Label Provider 2', 'Description for Label Provider 2', 'labeldir2')
+                ''')
+
+                # projects
+                models = self.database.models()
+                label_providers = self.database.label_providers()
+
+                for i in range(3):
+                    self.database.create_project(
+                        f'Project {i + 1}', f'Project Description {i + 1}',
+                        models[i],
+                        label_providers[i] if i < 2 else None,
+                        f'projectdir{i + 1}', i == 1, f'datadir{i + 1}'
+                    )
+
+    def tearDown(self) -> None:
+        self.database.close()
+
+    def test_models(self):
+        models = self.database.models()
+
+        # test length
+        self.assertEqual(len(models), 3)
+
+        # test insert
+        for i in range(2):
+            self.assertEqual(models[i].identifier, i + 1)
+            self.assertEqual(models[i].name, f'Model {i + 1}')
+            self.assertEqual(models[i].description, f'Description for Model {i + 1}')
+            self.assertEqual(models[i].root_folder, f'modeldir{i + 1}')
+
+        self.assertEqual(models[0].supports, ['labeled-image', 'fit'])
+        self.assertEqual(models[1].supports, ['labeled-bounding-boxes'])
+
+        # test copy
+        copy, _ = models[0].copy_to('Copied Model', 'modeldir3')
+        self.assertEqual(copy.identifier, 3)
+        self.assertEqual(copy.name, 'Copied Model')
+        self.assertEqual(copy.description, 'Description for Model 1')
+        self.assertEqual(copy.root_folder, 'modeldir3')
+        self.assertEqual(copy.supports, ['labeled-image', 'fit'])
+
+    def test_label_providers(self):
+        label_providers = self.database.label_providers()
+
+        # test length
+        self.assertEqual(len(label_providers), 2)
+
+        for i in range(2):
+            self.assertEqual(label_providers[i].identifier, i + 1)
+            self.assertEqual(label_providers[i].name, f'Label Provider {i + 1}')
+            self.assertEqual(label_providers[i].description, f'Description for Label Provider {i + 1}')
+            self.assertEqual(label_providers[i].root_folder, f'labeldir{i + 1}')
+
+    def test_projects(self):
+        models = self.database.models()
+        label_providers = self.database.label_providers()
+        projects = self.database.projects()
+
+        # create projects
+        for i in range(3):
+            project = projects[i]
+
+            self.assertEqual(project.identifier, i + 1)
+            self.assertEqual(project.name, f'Project {i + 1}')
+            self.assertEqual(project.description, f'Project Description {i + 1}')
+            self.assertEqual(project.model_id, i + 1)
+            self.assertEqual(project.model().__dict__, models[i].__dict__)
+            self.assertEqual(project.label_provider_id, label_providers[i].identifier if i < 2 else None)
+            self.assertEqual(
+                project.label_provider().__dict__ if project.label_provider() is not None else None,
+                label_providers[i].__dict__ if i < 2 else None
+            )
+            self.assertEqual(project.root_folder, f'projectdir{i + 1}')
+            self.assertEqual(project.external_data, i == 1)
+            self.assertEqual(project.data_folder, f'datadir{i + 1}')
+
+        # get projects
+        self.assertEqual(len(self.database.projects()), 3)
+
+        # remove a project
+        self.database.projects()[0].remove()
+        projects = self.database.projects()
+
+        self.assertEqual(len(projects), 2)
+        self.assertEqual(projects[0].name, 'Project 2')
+
+        # set properties
+        project = self.database.projects()[0]
+
+        project.set_name('Project 0')
+        self.assertEqual(self.database.projects()[0].name, 'Project 0')
+
+        project.set_description('Description 0')
+        self.assertEqual(self.database.projects()[0].description, 'Description 0')
+
+
+if __name__ == '__main__':
+    unittest.main()