Browse Source

fixed unittests

Dimitri Korsch 3 years ago
parent
commit
81ebf3d3e7
3 changed files with 56 additions and 33 deletions
  1. 12 1
      pycs/__init__.py
  2. 4 1
      pycs/frontend/endpoints/data/GetResizedFile.py
  3. 40 31
      test/test_database.py

+ 12 - 1
pycs/__init__.py

@@ -1,4 +1,5 @@
 import json
+import sys
 
 from pathlib import Path
 
@@ -13,7 +14,17 @@ with open('settings.json') as file:
 
 
 app = Flask(__name__)
-app.config["SQLALCHEMY_DATABASE_URI"] = f"sqlite:///{Path.cwd() / settings['database']}"
+
+if "unittest" in sys.modules:
+    # creates an in-memory DB
+    db_file = ""
+    app.logger.setLevel("CRITICAL")
+else:
+    db_file = Path.cwd() / settings['database']
+
+app.config["SQLALCHEMY_DATABASE_URI"] = f"sqlite:///{db_file}"
 app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False
+
 db = SQLAlchemy(app)
+db.session().execute("PRAGMA foreign_keys=ON")
 migrate = Migrate(app, db)

+ 4 - 1
pycs/frontend/endpoints/data/GetResizedFile.py

@@ -25,7 +25,10 @@ class GetResizedFile(View):
         # get file from database
         file = self.db.file(file_id)
         if file is None:
-            return abort(404)
+            return abort(404, "file object not found")
+
+        if not path.exists(file.path):
+            return abort(404, "image not found!")
 
         project = file.project
 

+ 40 - 31
test/test_database.py

@@ -1,47 +1,56 @@
 import unittest
 from contextlib import closing
 
+from pycs import db
 from pycs.database.Database import Database
+from pycs.database.Model import Model
+from pycs.database.LabelProvider import LabelProvider
 
 
 class TestDatabase(unittest.TestCase):
     def setUp(self) -> None:
+        db.create_all()
         # create database
         self.database = Database(discovery=False)
 
         # insert default models and label_providers
         with self.database:
-            with closing(self.database.con.cursor()) as cursor:
-                # models
-                cursor.execute('''
-                    INSERT INTO models (name, description, root_folder, supports)
-                    VALUES 
-                        ('Model 1', 'Description for Model 1', 'modeldir1', '["labeled-image", "fit"]'),
-                        ('Model 2', 'Description for Model 2', 'modeldir2', '["labeled-bounding-boxes"]'),
-                        ('Model 3', 'Description for Model 3', 'modeldir3', '["labeled-bounding-boxes"]')
-                ''')
-
-                # label providers
-                cursor.execute('''
-                    INSERT INTO label_providers (name, description, root_folder)
-                    VALUES
-                        ('Label Provider 1', 'Description for Label Provider 1', 'labeldir1'),
-                        ('Label Provider 2', 'Description for Label Provider 2', 'labeldir2')
-                ''')
-
-                # projects
-                models = list(self.database.models())
-                label_providers = list(self.database.label_providers())
-
-                for i in range(3):
-                    self.database.create_project(
-                        f'Project {i + 1}', f'Project Description {i + 1}',
-                        models[i],
-                        label_providers[i] if i < 2 else None,
-                        f'projectdir{i + 1}', i == 1, f'datadir{i + 1}'
-                    )
+
+            for i, supports in enumerate([["labeled-image", "fit"], ["labeled-bounding-boxes"], ["labeled-bounding-boxes"]], 1):
+
+                model = Model.new(
+                    name=f"Model {i}",
+                    description=f"Description for Model {i}",
+                    root_folder=f"modeldir{i}",
+                )
+                model.supports = supports
+
+                if i > 2:
+                    continue
+
+                provider = LabelProvider.new(
+                    name=f"Label Provider {i}",
+                    description=f"Description for Label Provider {i}",
+                    root_folder=f"labeldir{i}",
+                )
+
+        # projects
+        models = list(self.database.models())
+        label_providers = list(self.database.label_providers())
+
+        for i, model in enumerate(models, 1):
+            self.database.create_project(
+                name=f'Project {i}',
+                description=f'Project Description {i}',
+                model=model,
+                label_provider=label_providers[i-1] if i < 3 else None,
+                root_folder=f'projectdir{i}',
+                external_data=i==2,
+                data_folder=f'datadir{i}',
+            )
 
     def tearDown(self) -> None:
+        db.drop_all()
         self.database.close()
 
     def test_models(self):
@@ -93,10 +102,10 @@ class TestDatabase(unittest.TestCase):
             self.assertEqual(project.name, f'Project {i + 1}')
             self.assertEqual(project.description, f'Project Description {i + 1}')
             self.assertEqual(project.model_id, i + 1)
-            self.assertEqual(project.model().__dict__, models[i].__dict__)
+            self.assertEqual(project.model.__dict__, models[i].__dict__)
             self.assertEqual(project.label_provider_id, label_providers[i].id if i < 2 else None)
             self.assertEqual(
-                project.label_provider().__dict__ if project.label_provider() is not None else None,
+                project.label_provider.__dict__ if project.label_provider is not None else None,
                 label_providers[i].__dict__ if i < 2 else None
             )
             self.assertEqual(project.root_folder, f'projectdir{i + 1}')