123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127 |
- import unittest
- from contextlib import closing
- from pycs.database.Database import Database
class TestDatabase(unittest.TestCase):
    """Tests for ``Database`` covering models, label providers and projects.

    Every test starts from a fresh ``Database(discovery=False)`` that
    :meth:`setUp` fills with three models, two label providers and three
    projects.
    """

    def setUp(self) -> None:
        """Create a fresh database and insert the fixture rows used by every test."""
        # create database
        self.database = Database(discovery=False)
        # Register the close right away: unittest only runs tearDown when
        # setUp completes, whereas cleanups run even if the inserts below fail.
        self.addCleanup(self.database.close)

        # insert default models and label_providers
        with self.database:
            with closing(self.database.con.cursor()) as cursor:
                # models
                cursor.execute('''
                    INSERT INTO models (name, description, root_folder, supports)
                    VALUES
                        ('Model 1', 'Description for Model 1', 'modeldir1', '["labeled-image", "fit"]'),
                        ('Model 2', 'Description for Model 2', 'modeldir2', '["labeled-bounding-boxes"]'),
                        ('Model 3', 'Description for Model 3', 'modeldir3', '["labeled-bounding-boxes"]')
                ''')

                # label providers
                cursor.execute('''
                    INSERT INTO label_providers (name, description, root_folder)
                    VALUES
                        ('Label Provider 1', 'Description for Label Provider 1', 'labeldir1'),
                        ('Label Provider 2', 'Description for Label Provider 2', 'labeldir2')
                ''')

        # projects: project i uses model i; only the first two get a label
        # provider (only two exist), and only project 2 is flagged as external data.
        models = self.database.models()
        label_providers = self.database.label_providers()

        for i in range(3):
            self.database.create_project(
                f'Project {i + 1}', f'Project Description {i + 1}',
                models[i],
                label_providers[i] if i < 2 else None,
                f'projectdir{i + 1}', i == 1, f'datadir{i + 1}'
            )

    def test_models(self):
        """Check the fixture models and ``Model.copy_to``."""
        models = self.database.models()

        # test length
        self.assertEqual(len(models), 3)

        # test insert: verify every fixture model, not only the first two
        for i, model in enumerate(models):
            with self.subTest(model=i + 1):
                self.assertEqual(model.identifier, i + 1)
                self.assertEqual(model.name, f'Model {i + 1}')
                self.assertEqual(model.description, f'Description for Model {i + 1}')
                self.assertEqual(model.root_folder, f'modeldir{i + 1}')

        self.assertEqual(models[0].supports, ['labeled-image', 'fit'])
        self.assertEqual(models[1].supports, ['labeled-bounding-boxes'])
        self.assertEqual(models[2].supports, ['labeled-bounding-boxes'])

        # test copy
        # NOTE(review): 'modeldir3' is also the root folder of fixture Model 3,
        # so the expected identifier 3 presumably relies on copy_to reusing the
        # existing row for a duplicate root_folder — confirm against Model.copy_to.
        copy, _ = models[0].copy_to('Copied Model', 'modeldir3')
        self.assertEqual(copy.identifier, 3)
        self.assertEqual(copy.name, 'Copied Model')
        self.assertEqual(copy.description, 'Description for Model 1')
        self.assertEqual(copy.root_folder, 'modeldir3')
        self.assertEqual(copy.supports, ['labeled-image', 'fit'])

    def test_label_providers(self):
        """Check the two fixture label providers field by field."""
        label_providers = self.database.label_providers()

        # test length
        self.assertEqual(len(label_providers), 2)

        for i, label_provider in enumerate(label_providers):
            with self.subTest(label_provider=i + 1):
                self.assertEqual(label_provider.identifier, i + 1)
                self.assertEqual(label_provider.name, f'Label Provider {i + 1}')
                self.assertEqual(label_provider.description,
                                 f'Description for Label Provider {i + 1}')
                self.assertEqual(label_provider.root_folder, f'labeldir{i + 1}')

    def test_projects(self):
        """Check fixture projects, project removal and the name/description setters."""
        models = self.database.models()
        label_providers = self.database.label_providers()
        projects = self.database.projects()

        # verify the projects created in setUp
        for i in range(3):
            with self.subTest(project=i + 1):
                project = projects[i]
                # only the first two projects were given a label provider
                has_label_provider = i < 2

                self.assertEqual(project.identifier, i + 1)
                self.assertEqual(project.name, f'Project {i + 1}')
                self.assertEqual(project.description, f'Project Description {i + 1}')
                self.assertEqual(project.model_id, i + 1)
                self.assertEqual(project.model().__dict__, models[i].__dict__)
                self.assertEqual(
                    project.label_provider_id,
                    label_providers[i].identifier if has_label_provider else None
                )

                # fetch once instead of calling label_provider() twice
                label_provider = project.label_provider()
                self.assertEqual(
                    label_provider.__dict__ if label_provider is not None else None,
                    label_providers[i].__dict__ if has_label_provider else None
                )
                self.assertEqual(project.root_folder, f'projectdir{i + 1}')
                self.assertEqual(project.external_data, i == 1)
                self.assertEqual(project.data_folder, f'datadir{i + 1}')

        # get projects
        self.assertEqual(len(self.database.projects()), 3)

        # remove a project
        self.database.projects()[0].remove()
        projects = self.database.projects()
        self.assertEqual(len(projects), 2)
        self.assertEqual(projects[0].name, 'Project 2')

        # set properties
        project = self.database.projects()[0]

        project.set_name('Project 0')
        self.assertEqual(self.database.projects()[0].name, 'Project 0')

        project.set_description('Description 0')
        self.assertEqual(self.database.projects()[0].description, 'Description 0')
if __name__ == '__main__':
    # Only when this file is executed directly (not imported): hand control
    # to unittest's command-line test runner.
    unittest.main()
|