|
@@ -1,57 +1,66 @@
|
|
|
import unittest
|
|
|
-from contextlib import closing
|
|
|
|
|
|
-from pycs.database.Database import Database
|
|
|
+from pycs import db
|
|
|
+from pycs.database.File import File
|
|
|
+from pycs.database.Label import Label
|
|
|
+from pycs.database.LabelProvider import LabelProvider
|
|
|
+from pycs.database.Model import Model
|
|
|
+from pycs.database.Project import Project
|
|
|
+from pycs.database.Result import Result
|
|
|
|
|
|
+from test.base import BaseTestCase
|
|
|
|
|
|
-class TestDatabase(unittest.TestCase):
|
|
|
+
|
|
|
+class TestDatabase(BaseTestCase):
|
|
|
def setUp(self) -> None:
|
|
|
- # create database
|
|
|
- self.database = Database(discovery=False)
|
|
|
-
|
|
|
- # insert default models and label_providers
|
|
|
- with self.database:
|
|
|
- with closing(self.database.con.cursor()) as cursor:
|
|
|
- # models
|
|
|
- cursor.execute('''
|
|
|
- INSERT INTO models (name, description, root_folder, supports)
|
|
|
- VALUES
|
|
|
- ('Model 1', 'Description for Model 1', 'modeldir1', '["labeled-image", "fit"]'),
|
|
|
- ('Model 2', 'Description for Model 2', 'modeldir2', '["labeled-bounding-boxes"]'),
|
|
|
- ('Model 3', 'Description for Model 3', 'modeldir3', '["labeled-bounding-boxes"]')
|
|
|
- ''')
|
|
|
-
|
|
|
- # label providers
|
|
|
- cursor.execute('''
|
|
|
- INSERT INTO label_providers (name, description, root_folder, configuration_file)
|
|
|
- VALUES
|
|
|
- ('Label Provider 1', 'Description for Label Provider 1', 'labeldir1', 'file1'),
|
|
|
- ('Label Provider 2', 'Description for Label Provider 2', 'labeldir2', 'file2')
|
|
|
- ''')
|
|
|
-
|
|
|
- # projects
|
|
|
- models = list(self.database.models())
|
|
|
- label_providers = list(self.database.label_providers())
|
|
|
-
|
|
|
- for i in range(3):
|
|
|
- self.database.create_project(
|
|
|
- f'Project {i + 1}', f'Project Description {i + 1}',
|
|
|
- models[i],
|
|
|
- label_providers[i] if i < 2 else None,
|
|
|
- f'projectdir{i + 1}', i == 1, f'datadir{i + 1}'
|
|
|
- )
|
|
|
-
|
|
|
- def tearDown(self) -> None:
|
|
|
- self.database.close()
|
|
|
+ super().setUp(discovery=False)
|
|
|
+
|
|
|
+ with db.session.begin_nested():
|
|
|
+
|
|
|
+ for i, supports in enumerate([["labeled-image", "fit"], ["labeled-bounding-boxes"], ["labeled-bounding-boxes"]], 1):
|
|
|
+ model = Model.new(
|
|
|
+ commit=False,
|
|
|
+ name=f"Model {i}",
|
|
|
+ description=f"Description for Model {i}",
|
|
|
+ root_folder=f"modeldir{i}",
|
|
|
+ )
|
|
|
+ model.supports = supports
|
|
|
+
|
|
|
+ if i > 2:
|
|
|
+ continue
|
|
|
+
|
|
|
+ provider = LabelProvider.new(
|
|
|
+ commit=False,
|
|
|
+ name=f"Label Provider {i}",
|
|
|
+ description=f"Description for Label Provider {i}",
|
|
|
+ root_folder=f"labeldir{i}",
|
|
|
+ configuration_file=f"labeldir{i}/configuration.json"
|
|
|
+ )
|
|
|
+
|
|
|
+ # projects
|
|
|
+ models = Model.query.all()
|
|
|
+ label_providers = LabelProvider.query.all()
|
|
|
+
|
|
|
+ for i, model in enumerate(models, 1):
|
|
|
+ Project.new(
|
|
|
+ name=f'Project {i}',
|
|
|
+ description=f'Project Description {i}',
|
|
|
+ model=model,
|
|
|
+ label_provider=label_providers[i-1] if i < 3 else None,
|
|
|
+ root_folder=f'projectdir{i}',
|
|
|
+ external_data=i==2,
|
|
|
+ data_folder=f'datadir{i}',
|
|
|
+ )
|
|
|
+
|
|
|
|
|
|
def test_models(self):
|
|
|
- models = list(self.database.models())
|
|
|
+ models = Model.query.all()
|
|
|
|
|
|
# test length
|
|
|
self.assertEqual(len(models), 3)
|
|
|
|
|
|
# test insert
|
|
|
- for i in range(2):
|
|
|
+ for i in range(3):
|
|
|
self.assertEqual(models[i].id, i + 1)
|
|
|
self.assertEqual(models[i].name, f'Model {i + 1}')
|
|
|
self.assertEqual(models[i].description, f'Description for Model {i + 1}')
|
|
@@ -59,17 +68,19 @@ class TestDatabase(unittest.TestCase):
|
|
|
|
|
|
self.assertEqual(models[0].supports, ['labeled-image', 'fit'])
|
|
|
self.assertEqual(models[1].supports, ['labeled-bounding-boxes'])
|
|
|
+ self.assertEqual(models[2].supports, ['labeled-bounding-boxes'])
|
|
|
|
|
|
# test copy
|
|
|
- copy, _ = models[0].copy_to('Copied Model', 'modeldir3')
|
|
|
- self.assertEqual(copy.id, 3)
|
|
|
+ copy, is_new = models[0].copy_to('Copied Model', 'some_other_dir')
|
|
|
+ self.assertTrue(is_new)
|
|
|
+ self.assertEqual(copy.id, 4)
|
|
|
self.assertEqual(copy.name, 'Copied Model')
|
|
|
self.assertEqual(copy.description, 'Description for Model 1')
|
|
|
- self.assertEqual(copy.root_folder, 'modeldir3')
|
|
|
+ self.assertEqual(copy.root_folder, 'some_other_dir')
|
|
|
self.assertEqual(copy.supports, ['labeled-image', 'fit'])
|
|
|
|
|
|
def test_label_providers(self):
|
|
|
- label_providers = list(self.database.label_providers())
|
|
|
+ label_providers = LabelProvider.query.all()
|
|
|
|
|
|
# test length
|
|
|
self.assertEqual(len(label_providers), 2)
|
|
@@ -79,49 +90,52 @@ class TestDatabase(unittest.TestCase):
|
|
|
self.assertEqual(label_providers[i].name, f'Label Provider {i + 1}')
|
|
|
self.assertEqual(label_providers[i].description, f'Description for Label Provider {i + 1}')
|
|
|
self.assertEqual(label_providers[i].root_folder, f'labeldir{i + 1}')
|
|
|
- self.assertEqual(label_providers[i].configuration_file, f'file{i + 1}')
|
|
|
+ self.assertEqual(label_providers[i].configuration_file,
|
|
|
+ f"labeldir{i + 1}/configuration.json")
|
|
|
|
|
|
def test_projects(self):
|
|
|
- models = list(self.database.models())
|
|
|
- label_providers = list(self.database.label_providers())
|
|
|
- projects = list(self.database.projects())
|
|
|
+ models = Model.query.all()
|
|
|
+ label_providers = LabelProvider.query.all()
|
|
|
+ projects = Project.query.all()
|
|
|
+
|
|
|
+ # get projects
|
|
|
+ self.assertEqual(len(projects), 3)
|
|
|
|
|
|
# create projects
|
|
|
- for i in range(3):
|
|
|
- project = projects[i]
|
|
|
+ for i, project in enumerate(projects):
|
|
|
|
|
|
self.assertEqual(project.id, i + 1)
|
|
|
self.assertEqual(project.name, f'Project {i + 1}')
|
|
|
self.assertEqual(project.description, f'Project Description {i + 1}')
|
|
|
self.assertEqual(project.model_id, i + 1)
|
|
|
- self.assertEqual(project.model().__dict__, models[i].__dict__)
|
|
|
+ self.assertEqual(project.model.__dict__, models[i].__dict__)
|
|
|
self.assertEqual(project.label_provider_id, label_providers[i].id if i < 2 else None)
|
|
|
self.assertEqual(
|
|
|
- project.label_provider().__dict__ if project.label_provider() is not None else None,
|
|
|
+ project.label_provider.__dict__ if project.label_provider is not None else None,
|
|
|
label_providers[i].__dict__ if i < 2 else None
|
|
|
)
|
|
|
self.assertEqual(project.root_folder, f'projectdir{i + 1}')
|
|
|
self.assertEqual(project.external_data, i == 1)
|
|
|
self.assertEqual(project.data_folder, f'datadir{i + 1}')
|
|
|
|
|
|
- # get projects
|
|
|
- self.assertEqual(len(list(self.database.projects())), 3)
|
|
|
|
|
|
# remove a project
|
|
|
- list(self.database.projects())[0].remove()
|
|
|
- projects = list(self.database.projects())
|
|
|
+ Project.query.first().delete()
|
|
|
|
|
|
- self.assertEqual(len(projects), 2)
|
|
|
- self.assertEqual(projects[0].name, 'Project 2')
|
|
|
+ self.assertEqual(Project.query.count(), 2)
|
|
|
+ self.assertEqual(Project.query.first().name, 'Project 2')
|
|
|
|
|
|
# set properties
|
|
|
- project = list(self.database.projects())[0]
|
|
|
+ project = Project.query.first()
|
|
|
+
|
|
|
+ project.name = 'Project 0'
|
|
|
+ project.commit()
|
|
|
+ self.assertEqual(Project.query.first().name, 'Project 0')
|
|
|
|
|
|
- project.set_name('Project 0')
|
|
|
- self.assertEqual(list(self.database.projects())[0].name, 'Project 0')
|
|
|
+ project.description = 'Description 0'
|
|
|
+ project.commit()
|
|
|
+ self.assertEqual(Project.query.first().description, 'Description 0')
|
|
|
|
|
|
- project.set_description('Description 0')
|
|
|
- self.assertEqual(list(self.database.projects())[0].description, 'Description 0')
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|