瀏覽代碼

fixed database tests

Dimitri Korsch 3 年之前
父節點
當前提交
96c654a43a
共有 7 個文件被更改,包括 149 次插入73 次删除
  1. 11 0
      Makefile
  2. 9 1
      pycs/__init__.py
  3. 3 2
      pycs/database/Project.py
  4. 4 5
      pycs/frontend/WebServer.py
  5. 0 0
      test/__init__.py
  6. 43 0
      test/base.py
  7. 79 65
      test/test_database.py

+ 11 - 0
Makefile

@@ -0,0 +1,11 @@
+run:
+	python app.py
+
+run_webui:
+	@cd webui && npm run serve
+
+install:
+	@echo "INSTALL MISSING!"
+
+run_tests:
+	python -m unittest discover test/

+ 9 - 1
pycs/__init__.py

@@ -1,4 +1,5 @@
 import json
+import sys
 import os
 
 from pathlib import Path
@@ -20,7 +21,14 @@ if not os.path.exists(settings.projects_folder):
     os.mkdir(settings.projects_folder)
 
 app = Flask(__name__)
-app.config["SQLALCHEMY_DATABASE_URI"] = f"sqlite:///{Path.cwd() / settings.database}"
+
+if "unittest" in sys.modules:
+    # creates an in-memory DB
+    db_file = ""
+else:
+    db_file = Path.cwd() / settings.database
+
+app.config["SQLALCHEMY_DATABASE_URI"] = f"sqlite:///{db_file}"
 app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False
 
 @event.listens_for(Engine, "connect")

+ 3 - 2
pycs/database/Project.py

@@ -72,8 +72,9 @@ class Project(NamedBaseModel):
         if self.model_id is not None:
             model_dump = self.model.delete(commit=False)
 
-        # remove from file system
-        shutil.rmtree(self.root_folder)
+        if os.path.exists(self.root_folder):
+            # remove from file system
+            shutil.rmtree(self.root_folder)
 
         return dump, model_dump
 

+ 4 - 5
pycs/frontend/WebServer.py

@@ -62,9 +62,7 @@ class WebServer:
     wrapper class for flask and socket.io which initializes most networking
     """
 
-    # pylint: disable=line-too-long
-    # pylint: disable=too-many-statements
-    def __init__(self, app, settings: dict):
+    def __init__(self, app, settings: dict, discovery: bool = True):
 
         PRODUCTION = os.path.exists('webui/index.html')
 
@@ -133,8 +131,9 @@ class WebServer:
 
         self.define_routes(jobs, notifications, pipelines)
 
-        Model.discover("models/")
-        LabelProvider.discover("labels/")
+        if discovery:
+            Model.discover("models/")
+            LabelProvider.discover("labels/")
 
 
     def define_routes(self, jobs, notifications, pipelines):

+ 0 - 0
test/__init__.py


+ 43 - 0
test/base.py

@@ -0,0 +1,43 @@
+
+import os
+import shutil
+import unittest
+
+from pycs import app
+from pycs import db
+from pycs import settings
+from pycs.frontend.WebServer import WebServer
+from pycs.database.Model import Model
+from pycs.database.LabelProvider import LabelProvider
+
+server = None
+
+class BaseTestCase(unittest.TestCase):
+    def setUp(self, discovery: bool = True):
+        global server
+        app.config["TESTING"] = True
+        self.projects_dir = app.config["TEST_PROJECTS_DIR"] = "test_projects"
+        app.config["WTF_CSRF_ENABLED"] = False
+        app.config["DEBUG"] = False
+        app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///"
+
+        db.create_all()
+
+        self.client = app.test_client()
+
+        # init the server once
+        if server is None:
+            server = WebServer(app, settings, discovery)
+
+        elif discovery:
+            # run discovery modules manually
+            Model.discover("models/")
+            LabelProvider.discover("labels/")
+
+    def tearDown(self):
+
+        if os.path.exists(self.projects_dir):
+            shutil.rmtree(self.projects_dir)
+
+        db.drop_all()
+

+ 79 - 65
test/test_database.py

@@ -1,57 +1,66 @@
 import unittest
-from contextlib import closing
 
-from pycs.database.Database import Database
+from pycs import db
+from pycs.database.File import File
+from pycs.database.Label import Label
+from pycs.database.LabelProvider import LabelProvider
+from pycs.database.Model import Model
+from pycs.database.Project import Project
+from pycs.database.Result import Result
 
+from test.base import BaseTestCase
 
-class TestDatabase(unittest.TestCase):
+
+class TestDatabase(BaseTestCase):
     def setUp(self) -> None:
-        # create database
-        self.database = Database(discovery=False)
-
-        # insert default models and label_providers
-        with self.database:
-            with closing(self.database.con.cursor()) as cursor:
-                # models
-                cursor.execute('''
-                    INSERT INTO models (name, description, root_folder, supports)
-                    VALUES 
-                        ('Model 1', 'Description for Model 1', 'modeldir1', '["labeled-image", "fit"]'),
-                        ('Model 2', 'Description for Model 2', 'modeldir2', '["labeled-bounding-boxes"]'),
-                        ('Model 3', 'Description for Model 3', 'modeldir3', '["labeled-bounding-boxes"]')
-                ''')
-
-                # label providers
-                cursor.execute('''
-                    INSERT INTO label_providers (name, description, root_folder, configuration_file)
-                    VALUES
-                        ('Label Provider 1', 'Description for Label Provider 1', 'labeldir1', 'file1'),
-                        ('Label Provider 2', 'Description for Label Provider 2', 'labeldir2', 'file2')
-                ''')
-
-                # projects
-                models = list(self.database.models())
-                label_providers = list(self.database.label_providers())
-
-                for i in range(3):
-                    self.database.create_project(
-                        f'Project {i + 1}', f'Project Description {i + 1}',
-                        models[i],
-                        label_providers[i] if i < 2 else None,
-                        f'projectdir{i + 1}', i == 1, f'datadir{i + 1}'
-                    )
-
-    def tearDown(self) -> None:
-        self.database.close()
+        super().setUp(discovery=False)
+
+        with db.session.begin_nested():
+
+            for i, supports in enumerate([["labeled-image", "fit"], ["labeled-bounding-boxes"], ["labeled-bounding-boxes"]], 1):
+                model = Model.new(
+                    commit=False,
+                    name=f"Model {i}",
+                    description=f"Description for Model {i}",
+                    root_folder=f"modeldir{i}",
+                )
+                model.supports = supports
+
+                if i > 2:
+                    continue
+
+                provider = LabelProvider.new(
+                    commit=False,
+                    name=f"Label Provider {i}",
+                    description=f"Description for Label Provider {i}",
+                    root_folder=f"labeldir{i}",
+                    configuration_file=f"labeldir{i}/configuration.json"
+                )
+
+        # projects
+        models = Model.query.all()
+        label_providers = LabelProvider.query.all()
+
+        for i, model in enumerate(models, 1):
+            Project.new(
+                name=f'Project {i}',
+                description=f'Project Description {i}',
+                model=model,
+                label_provider=label_providers[i-1] if i < 3 else None,
+                root_folder=f'projectdir{i}',
+                external_data=i==2,
+                data_folder=f'datadir{i}',
+            )
+
 
     def test_models(self):
-        models = list(self.database.models())
+        models = Model.query.all()
 
         # test length
         self.assertEqual(len(models), 3)
 
         # test insert
-        for i in range(2):
+        for i in range(3):
             self.assertEqual(models[i].id, i + 1)
             self.assertEqual(models[i].name, f'Model {i + 1}')
             self.assertEqual(models[i].description, f'Description for Model {i + 1}')
@@ -59,17 +68,19 @@ class TestDatabase(unittest.TestCase):
 
         self.assertEqual(models[0].supports, ['labeled-image', 'fit'])
         self.assertEqual(models[1].supports, ['labeled-bounding-boxes'])
+        self.assertEqual(models[2].supports, ['labeled-bounding-boxes'])
 
         # test copy
-        copy, _ = models[0].copy_to('Copied Model', 'modeldir3')
-        self.assertEqual(copy.id, 3)
+        copy, is_new = models[0].copy_to('Copied Model', 'some_other_dir')
+        self.assertTrue(is_new)
+        self.assertEqual(copy.id, 4)
         self.assertEqual(copy.name, 'Copied Model')
         self.assertEqual(copy.description, 'Description for Model 1')
-        self.assertEqual(copy.root_folder, 'modeldir3')
+        self.assertEqual(copy.root_folder, 'some_other_dir')
         self.assertEqual(copy.supports, ['labeled-image', 'fit'])
 
     def test_label_providers(self):
-        label_providers = list(self.database.label_providers())
+        label_providers = LabelProvider.query.all()
 
         # test length
         self.assertEqual(len(label_providers), 2)
@@ -79,49 +90,52 @@ class TestDatabase(unittest.TestCase):
             self.assertEqual(label_providers[i].name, f'Label Provider {i + 1}')
             self.assertEqual(label_providers[i].description, f'Description for Label Provider {i + 1}')
             self.assertEqual(label_providers[i].root_folder, f'labeldir{i + 1}')
-            self.assertEqual(label_providers[i].configuration_file, f'file{i + 1}')
+            self.assertEqual(label_providers[i].configuration_file,
+                f"labeldir{i + 1}/configuration.json")
 
     def test_projects(self):
-        models = list(self.database.models())
-        label_providers = list(self.database.label_providers())
-        projects = list(self.database.projects())
+        models = Model.query.all()
+        label_providers = LabelProvider.query.all()
+        projects = Project.query.all()
+
+        # get projects
+        self.assertEqual(len(projects), 3)
 
         # create projects
-        for i in range(3):
-            project = projects[i]
+        for i, project in enumerate(projects):
 
             self.assertEqual(project.id, i + 1)
             self.assertEqual(project.name, f'Project {i + 1}')
             self.assertEqual(project.description, f'Project Description {i + 1}')
             self.assertEqual(project.model_id, i + 1)
-            self.assertEqual(project.model().__dict__, models[i].__dict__)
+            self.assertEqual(project.model.__dict__, models[i].__dict__)
             self.assertEqual(project.label_provider_id, label_providers[i].id if i < 2 else None)
             self.assertEqual(
-                project.label_provider().__dict__ if project.label_provider() is not None else None,
+                project.label_provider.__dict__ if project.label_provider is not None else None,
                 label_providers[i].__dict__ if i < 2 else None
             )
             self.assertEqual(project.root_folder, f'projectdir{i + 1}')
             self.assertEqual(project.external_data, i == 1)
             self.assertEqual(project.data_folder, f'datadir{i + 1}')
 
-        # get projects
-        self.assertEqual(len(list(self.database.projects())), 3)
 
         # remove a project
-        list(self.database.projects())[0].remove()
-        projects = list(self.database.projects())
+        Project.query.first().delete()
 
-        self.assertEqual(len(projects), 2)
-        self.assertEqual(projects[0].name, 'Project 2')
+        self.assertEqual(Project.query.count(), 2)
+        self.assertEqual(Project.query.first().name, 'Project 2')
 
         # set properties
-        project = list(self.database.projects())[0]
+        project = Project.query.first()
+
+        project.name = 'Project 0'
+        project.commit()
+        self.assertEqual(Project.query.first().name, 'Project 0')
 
-        project.set_name('Project 0')
-        self.assertEqual(list(self.database.projects())[0].name, 'Project 0')
+        project.description = 'Description 0'
+        project.commit()
+        self.assertEqual(Project.query.first().description, 'Description 0')
 
-        project.set_description('Description 0')
-        self.assertEqual(list(self.database.projects())[0].description, 'Description 0')
 
 
 if __name__ == '__main__':