123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207 |
- import unittest
- from pycs.database.Database import Database
- from pycs.database.File import File
- from pycs.database.Label import Label
- from pycs.database.Result import Result
- from pycs.database.Model import Model
- from pycs.database.LabelProvider import LabelProvider
- from test.base import BaseTestCase
class DatabaseTests(BaseTestCase):
    """Tests for models, label providers, projects and cascading deletes."""

    def setUp(self) -> None:
        """Insert three models, two label providers and one project per model.

        Model i gets label provider i for i in (1, 2); model 3 has none.
        Project i uses model i, label provider i (project 3: None),
        ``external_data`` only for project 2, and folders
        ``projectdir{i}`` / ``datadir{i}``.
        """
        super().setUp(discovery=False)

        # Capabilities assigned to Model 1, Model 2 and Model 3 respectively.
        model_supports = [
            ["labeled-image", "fit"],
            ["labeled-bounding-boxes"],
            ["labeled-bounding-boxes"],
        ]

        # insert default models and label_providers; each insert passes
        # commit=False and all of them happen inside one `with self.database:`
        # block (presumably committed when the block exits — confirm in Database).
        with self.database:
            for i, supports in enumerate(model_supports, 1):
                model = Model.new(
                    commit=False,
                    name=f"Model {i}",
                    description=f"Description for Model {i}",
                    root_folder=f"modeldir{i}",
                )
                model.supports = supports

                # only the first two models get a matching label provider
                if i > 2:
                    continue

                LabelProvider.new(
                    commit=False,
                    name=f"Label Provider {i}",
                    description=f"Description for Label Provider {i}",
                    root_folder=f"labeldir{i}",
                )

        # projects: one per model, paired with the label provider of the same index
        models = list(self.database.models())
        label_providers = list(self.database.label_providers())

        for i, model in enumerate(models, 1):
            self.database.create_project(
                name=f'Project {i}',
                description=f'Project Description {i}',
                model=model,
                label_provider=label_providers[i - 1] if i < 3 else None,
                root_folder=f'projectdir{i}',
                external_data=i == 2,
                data_folder=f'datadir{i}',
            )

    def test_models(self):
        """Check the fixture models and that ``copy_to`` yields a new model."""
        models = list(self.database.models())

        # test length
        self.assertEqual(len(models), 3)

        # test insert: check every fixture model, not only the first two
        for i, model in enumerate(models, 1):
            self.assertEqual(model.id, i)
            self.assertEqual(model.name, f'Model {i}')
            self.assertEqual(model.description, f'Description for Model {i}')
            self.assertEqual(model.root_folder, f'modeldir{i}')

        self.assertEqual(models[0].supports, ['labeled-image', 'fit'])
        self.assertEqual(models[1].supports, ['labeled-bounding-boxes'])
        self.assertEqual(models[2].supports, ['labeled-bounding-boxes'])

        # test copy: setUp already created Model 3 in 'modeldir3' with id 3,
        # so copying into that folder would collide with a fixture model.
        # Use the next free folder/id instead.
        new_id = len(models) + 1
        new_folder = f'modeldir{new_id}'

        copy, _ = models[0].copy_to('Copied Model', new_folder)
        self.assertEqual(copy.id, new_id)
        self.assertEqual(copy.name, 'Copied Model')
        self.assertEqual(copy.description, 'Description for Model 1')
        self.assertEqual(copy.root_folder, new_folder)
        self.assertEqual(copy.supports, ['labeled-image', 'fit'])

        # the copy is an additional model; the fixture Model 3 is unchanged
        models_after = list(self.database.models())
        self.assertEqual(len(models_after), new_id)
        self.assertEqual(models_after[2].name, 'Model 3')

    def test_label_providers(self):
        """Check the two label providers inserted in setUp."""
        label_providers = list(self.database.label_providers())

        # test length
        self.assertEqual(len(label_providers), 2)

        for i, provider in enumerate(label_providers, 1):
            self.assertEqual(provider.id, i)
            self.assertEqual(provider.name, f'Label Provider {i}')
            self.assertEqual(provider.description, f'Description for Label Provider {i}')
            self.assertEqual(provider.root_folder, f'labeldir{i}')

    def test_projects(self):
        """Check project fields, removal, and the name/description setters."""
        models = list(self.database.models())
        label_providers = list(self.database.label_providers())
        projects = list(self.database.projects())
        self.assertEqual(len(projects), 3)

        # check the projects created in setUp
        for i, project in enumerate(projects):
            self.assertEqual(project.id, i + 1)
            self.assertEqual(project.name, f'Project {i + 1}')
            self.assertEqual(project.description, f'Project Description {i + 1}')
            self.assertEqual(project.model_id, i + 1)
            self.assertEqual(project.model.__dict__, models[i].__dict__)

            # only the first two projects were given a label provider
            self.assertEqual(project.label_provider_id, label_providers[i].id if i < 2 else None)
            self.assertEqual(
                project.label_provider.__dict__ if project.label_provider is not None else None,
                label_providers[i].__dict__ if i < 2 else None
            )
            self.assertEqual(project.root_folder, f'projectdir{i + 1}')
            self.assertEqual(project.external_data, i == 1)
            self.assertEqual(project.data_folder, f'datadir{i + 1}')

        # get projects
        self.assertEqual(len(list(self.database.projects())), 3)

        # remove a project
        list(self.database.projects())[0].remove()
        projects = list(self.database.projects())
        self.assertEqual(len(projects), 2)
        self.assertEqual(projects[0].name, 'Project 2')

        # set properties
        project = list(self.database.projects())[0]
        project.set_name('Project 0')
        self.assertEqual(list(self.database.projects())[0].name, 'Project 0')

        project.set_description('Description 0')
        self.assertEqual(list(self.database.projects())[0].description, 'Description 0')

    def test_no_files_after_project_deletion(self):
        """Removing a project must also remove its files."""
        project = self.database.project(1)
        # keep the id: it is used for the query after the row is removed
        project_id = project.id

        for i in range(5):
            file, is_new = project.add_file(
                uuid=f"some_string{i}",
                name=f"some_name{i}",
                filename=f"some_filename{i}",
                file_type="image",
                extension=".jpg",
                size=42,
            )
            self.assertTrue(is_new)
            self.assertIsNotNone(file)

        self.assertEqual(5, File.query.filter_by(project_id=project_id).count())

        project.remove()

        self.assertIsNone(self.database.project(1))
        self.assertEqual(0, File.query.filter_by(project_id=project_id).count())

    def test_no_labels_after_project_deletion(self):
        """Removing a project must also remove its labels."""
        self.assertEqual(0, Label.query.count())

        project = self.database.project(1)
        # keep the id: it is used for the query after the row is removed
        project_id = project.id

        for i in range(5):
            label, is_new = project.create_label(
                name=f"label{i}",
                reference=f"ref{i}"
            )
            self.assertTrue(is_new)
            self.assertIsNotNone(label)

        self.assertEqual(5, Label.query.filter_by(project_id=project_id).count())

        project.remove()

        self.assertIsNone(self.database.project(1))
        self.assertEqual(0, Label.query.count())

    def test_no_results_after_file_deletion(self):
        """Deleting a file row must also delete the results attached to it."""
        self.assertEqual(0, Result.query.count())

        project = self.database.project(1)
        self.assertIsNotNone(project)

        file, is_new = project.add_file(
            uuid="some_string",
            name="some_name",
            filename="some_filename",
            file_type="image",
            extension=".jpg",
            size=42,
        )
        self.assertTrue(is_new)
        self.assertIsNotNone(file)

        for _ in range(5):
            file.create_result(
                origin="pipeline",
                result_type="bounding_box",
                label=None,
            )

        self.assertEqual(5, Result.query.count())

        # NOTE(review): assuming File.query is a SQLAlchemy Query, .delete()
        # is a bulk DELETE that skips ORM relationship cascades, so this
        # presumably exercises a database-level ON DELETE cascade — confirm.
        File.query.filter_by(id=file.id).delete()
        self.assertEqual(0, Result.query.count())
if __name__ == "__main__":
    # Executed directly (not imported): hand control to unittest's CLI runner,
    # which discovers and runs the TestCase classes defined in this module.
    unittest.main()
|