|
@@ -1,47 +1,56 @@
|
|
|
import unittest
|
|
|
from contextlib import closing
|
|
|
|
|
|
+from pycs import db
|
|
|
from pycs.database.Database import Database
|
|
|
+from pycs.database.Model import Model
|
|
|
+from pycs.database.LabelProvider import LabelProvider
|
|
|
|
|
|
|
|
|
class TestDatabase(unittest.TestCase):
|
|
|
def setUp(self) -> None:
|
|
|
+ db.create_all()
|
|
|
# create database
|
|
|
self.database = Database(discovery=False)
|
|
|
|
|
|
# insert default models and label_providers
|
|
|
with self.database:
|
|
|
- with closing(self.database.con.cursor()) as cursor:
|
|
|
- # models
|
|
|
- cursor.execute('''
|
|
|
- INSERT INTO models (name, description, root_folder, supports)
|
|
|
- VALUES
|
|
|
- ('Model 1', 'Description for Model 1', 'modeldir1', '["labeled-image", "fit"]'),
|
|
|
- ('Model 2', 'Description for Model 2', 'modeldir2', '["labeled-bounding-boxes"]'),
|
|
|
- ('Model 3', 'Description for Model 3', 'modeldir3', '["labeled-bounding-boxes"]')
|
|
|
- ''')
|
|
|
-
|
|
|
- # label providers
|
|
|
- cursor.execute('''
|
|
|
- INSERT INTO label_providers (name, description, root_folder)
|
|
|
- VALUES
|
|
|
- ('Label Provider 1', 'Description for Label Provider 1', 'labeldir1'),
|
|
|
- ('Label Provider 2', 'Description for Label Provider 2', 'labeldir2')
|
|
|
- ''')
|
|
|
-
|
|
|
- # projects
|
|
|
- models = list(self.database.models())
|
|
|
- label_providers = list(self.database.label_providers())
|
|
|
-
|
|
|
- for i in range(3):
|
|
|
- self.database.create_project(
|
|
|
- f'Project {i + 1}', f'Project Description {i + 1}',
|
|
|
- models[i],
|
|
|
- label_providers[i] if i < 2 else None,
|
|
|
- f'projectdir{i + 1}', i == 1, f'datadir{i + 1}'
|
|
|
- )
|
|
|
+
|
|
|
+ for i, supports in enumerate([["labeled-image", "fit"], ["labeled-bounding-boxes"], ["labeled-bounding-boxes"]], 1):
|
|
|
+
|
|
|
+ model = Model.new(
|
|
|
+ name=f"Model {i}",
|
|
|
+ description=f"Description for Model {i}",
|
|
|
+ root_folder=f"modeldir{i}",
|
|
|
+ )
|
|
|
+ model.supports = supports
|
|
|
+
|
|
|
+ if i > 2:
|
|
|
+ continue
|
|
|
+
|
|
|
+ provider = LabelProvider.new(
|
|
|
+ name=f"Label Provider {i}",
|
|
|
+ description=f"Description for Label Provider {i}",
|
|
|
+ root_folder=f"labeldir{i}",
|
|
|
+ )
|
|
|
+
|
|
|
+ # projects
|
|
|
+ models = list(self.database.models())
|
|
|
+ label_providers = list(self.database.label_providers())
|
|
|
+
|
|
|
+ for i, model in enumerate(models, 1):
|
|
|
+ self.database.create_project(
|
|
|
+ name=f'Project {i}',
|
|
|
+ description=f'Project Description {i}',
|
|
|
+ model=model,
|
|
|
+ label_provider=label_providers[i-1] if i < 3 else None,
|
|
|
+ root_folder=f'projectdir{i}',
|
|
|
+ external_data=i==2,
|
|
|
+ data_folder=f'datadir{i}',
|
|
|
+ )
|
|
|
|
|
|
def tearDown(self) -> None:
|
|
|
+ db.drop_all()
|
|
|
self.database.close()
|
|
|
|
|
|
def test_models(self):
|
|
@@ -93,10 +102,10 @@ class TestDatabase(unittest.TestCase):
|
|
|
self.assertEqual(project.name, f'Project {i + 1}')
|
|
|
self.assertEqual(project.description, f'Project Description {i + 1}')
|
|
|
self.assertEqual(project.model_id, i + 1)
|
|
|
- self.assertEqual(project.model().__dict__, models[i].__dict__)
|
|
|
+ self.assertEqual(project.model.__dict__, models[i].__dict__)
|
|
|
self.assertEqual(project.label_provider_id, label_providers[i].id if i < 2 else None)
|
|
|
self.assertEqual(
|
|
|
- project.label_provider().__dict__ if project.label_provider() is not None else None,
|
|
|
+ project.label_provider.__dict__ if project.label_provider is not None else None,
|
|
|
label_providers[i].__dict__ if i < 2 else None
|
|
|
)
|
|
|
self.assertEqual(project.root_folder, f'projectdir{i + 1}')
|