import unittest
from contextlib import closing

from pycs.database.Database import Database


class TestDatabase(unittest.TestCase):
    """Tests for ``Database`` and its model, label provider and project accessors."""

    def setUp(self) -> None:
        """Create a fresh database holding three models, two label providers
        and three projects."""
        # create database (discovery disabled so only the rows below exist)
        self.database = Database(discovery=False)

        # insert default models and label_providers
        with self.database:
            with closing(self.database.con.cursor()) as cursor:
                # models
                cursor.execute('''
                    INSERT INTO models (name, description, root_folder, supports)
                    VALUES
                        ('Model 1', 'Description for Model 1', 'modeldir1',
                         '["labeled-image", "fit"]'),
                        ('Model 2', 'Description for Model 2', 'modeldir2',
                         '["labeled-bounding-boxes"]'),
                        ('Model 3', 'Description for Model 3', 'modeldir3',
                         '["labeled-bounding-boxes"]')
                ''')

                # label providers
                cursor.execute('''
                    INSERT INTO label_providers (name, description, root_folder, configuration_file)
                    VALUES
                        ('Label Provider 1', 'Description for Label Provider 1',
                         'labeldir1', 'file1'),
                        ('Label Provider 2', 'Description for Label Provider 2',
                         'labeldir2', 'file2')
                ''')

        # projects: project i uses model i; only the first two projects get a
        # label provider, and only the second one (i == 1) is passed True for
        # the flag that test_projects checks as ``external_data``
        models = list(self.database.models())
        label_providers = list(self.database.label_providers())

        for i in range(3):
            self.database.create_project(
                f'Project {i + 1}',
                f'Project Description {i + 1}',
                models[i],
                label_providers[i] if i < 2 else None,
                f'projectdir{i + 1}',
                i == 1,
                f'datadir{i + 1}'
            )

    def tearDown(self) -> None:
        """Close the database created in setUp."""
        self.database.close()

    def test_models(self):
        """Check the models inserted in setUp and ``Model.copy_to``."""
        models = list(self.database.models())

        # test length
        self.assertEqual(len(models), 3)

        # test insert: verify every model from setUp, not just the first two
        for i, model in enumerate(models):
            with self.subTest(model=i + 1):
                self.assertEqual(model.identifier, i + 1)
                self.assertEqual(model.name, f'Model {i + 1}')
                self.assertEqual(model.description, f'Description for Model {i + 1}')
                self.assertEqual(model.root_folder, f'modeldir{i + 1}')

        self.assertEqual(models[0].supports, ['labeled-image', 'fit'])
        self.assertEqual(models[1].supports, ['labeled-bounding-boxes'])
        self.assertEqual(models[2].supports, ['labeled-bounding-boxes'])

        # test copy
        # NOTE(review): 'modeldir3' is already the root folder of Model 3, so
        # expecting identifier 3 only holds if copy_to overwrites the row that
        # shares the root folder instead of inserting a new one — confirm
        # against Model.copy_to.
        copy, _ = models[0].copy_to('Copied Model', 'modeldir3')
        self.assertEqual(copy.identifier, 3)
        self.assertEqual(copy.name, 'Copied Model')
        self.assertEqual(copy.description, 'Description for Model 1')
        self.assertEqual(copy.root_folder, 'modeldir3')
        self.assertEqual(copy.supports, ['labeled-image', 'fit'])

    def test_label_providers(self):
        """Check the label providers inserted in setUp."""
        label_providers = list(self.database.label_providers())

        # test length
        self.assertEqual(len(label_providers), 2)

        for i, provider in enumerate(label_providers):
            with self.subTest(label_provider=i + 1):
                self.assertEqual(provider.identifier, i + 1)
                self.assertEqual(provider.name, f'Label Provider {i + 1}')
                self.assertEqual(provider.description,
                                 f'Description for Label Provider {i + 1}')
                self.assertEqual(provider.root_folder, f'labeldir{i + 1}')
                self.assertEqual(provider.configuration_file, f'file{i + 1}')

    def test_projects(self):
        """Check project creation, listing, removal and property setters."""
        models = list(self.database.models())
        label_providers = list(self.database.label_providers())
        projects = list(self.database.projects())

        # get projects: check the count first so a short list fails clearly
        # instead of raising IndexError below
        self.assertEqual(len(projects), 3)

        # check the projects created in setUp
        for i, project in enumerate(projects):
            with self.subTest(project=i + 1):
                expected_provider = label_providers[i] if i < 2 else None

                self.assertEqual(project.identifier, i + 1)
                self.assertEqual(project.name, f'Project {i + 1}')
                self.assertEqual(project.description, f'Project Description {i + 1}')
                self.assertEqual(project.model_id, i + 1)
                self.assertEqual(project.model().__dict__, models[i].__dict__)

                self.assertEqual(
                    project.label_provider_id,
                    expected_provider.identifier if expected_provider is not None else None
                )
                # fetch once and compare attribute dicts (or None for no provider)
                provider = project.label_provider()
                self.assertEqual(
                    provider.__dict__ if provider is not None else None,
                    expected_provider.__dict__ if expected_provider is not None else None
                )

                self.assertEqual(project.root_folder, f'projectdir{i + 1}')
                self.assertEqual(project.external_data, i == 1)
                self.assertEqual(project.data_folder, f'datadir{i + 1}')

        # the listing is still stable when fetched again
        self.assertEqual(len(list(self.database.projects())), 3)

        # remove a project
        list(self.database.projects())[0].remove()
        projects = list(self.database.projects())
        self.assertEqual(len(projects), 2)
        self.assertEqual(projects[0].name, 'Project 2')

        # set properties (re-fetch each time so the persisted value is checked)
        project = list(self.database.projects())[0]

        project.set_name('Project 0')
        self.assertEqual(list(self.database.projects())[0].name, 'Project 0')

        project.set_description('Description 0')
        self.assertEqual(list(self.database.projects())[0].description, 'Description 0')


if __name__ == '__main__':
    unittest.main()