import unittest
from contextlib import closing

from pycs import db
from pycs.database.Database import Database
from pycs.database.Model import Model
from pycs.database.LabelProvider import LabelProvider


class TestDatabase(unittest.TestCase):
    """Tests for the pycs ``Database`` wrapper: models, label providers and projects."""

    def setUp(self) -> None:
        """Create a fresh schema and populate it with fixtures.

        Inserts three models (``Model 1``..``Model 3``), two label providers
        (``Label Provider 1``, ``Label Provider 2``) and three projects.
        Project ``i`` uses model ``i``; projects 1 and 2 use label providers
        1 and 2, project 3 has none; only project 2 has ``external_data`` set.
        """
        db.create_all()

        # create database
        self.database = Database(discovery=False)

        # insert default models and label_providers
        model_supports = [
            ["labeled-image", "fit"],
            ["labeled-bounding-boxes"],
            ["labeled-bounding-boxes"],
        ]
        with self.database:
            for i, supports in enumerate(model_supports, 1):
                model = Model.new(
                    name=f"Model {i}",
                    description=f"Description for Model {i}",
                    root_folder=f"modeldir{i}",
                )
                model.supports = supports

                # only the first two models get a matching label provider
                if i > 2:
                    continue

                LabelProvider.new(
                    name=f"Label Provider {i}",
                    description=f"Description for Label Provider {i}",
                    root_folder=f"labeldir{i}",
                )

        # projects: read back what was inserted above and build one project per model
        models = list(self.database.models())
        label_providers = list(self.database.label_providers())

        for i, model in enumerate(models, 1):
            self.database.create_project(
                name=f'Project {i}',
                description=f'Project Description {i}',
                model=model,
                label_provider=label_providers[i - 1] if i < 3 else None,
                root_folder=f'projectdir{i}',
                external_data=(i == 2),
                data_folder=f'datadir{i}',
            )

    def tearDown(self) -> None:
        """Drop all tables created in ``setUp`` and close the database wrapper."""
        db.drop_all()
        self.database.close()

    def test_models(self):
        """Models from ``setUp`` are stored in order and can be copied."""
        models = list(self.database.models())

        # test length
        self.assertEqual(len(models), 3)

        # test insert: check every model, not just the first two
        for i, model in enumerate(models, 1):
            self.assertEqual(model.id, i)
            self.assertEqual(model.name, f'Model {i}')
            self.assertEqual(model.description, f'Description for Model {i}')
            self.assertEqual(model.root_folder, f'modeldir{i}')

        self.assertEqual(models[0].supports, ['labeled-image', 'fit'])
        self.assertEqual(models[1].supports, ['labeled-bounding-boxes'])
        self.assertEqual(models[2].supports, ['labeled-bounding-boxes'])

        # test copy
        # NOTE(review): 'modeldir3' is already Model 3's root_folder, so
        # expecting id 3 presumes copy_to reuses the row with that folder
        # rather than inserting a new one — confirm against Model.copy_to.
        copy, _ = models[0].copy_to('Copied Model', 'modeldir3')
        self.assertEqual(copy.id, 3)
        self.assertEqual(copy.name, 'Copied Model')
        self.assertEqual(copy.description, 'Description for Model 1')
        self.assertEqual(copy.root_folder, 'modeldir3')
        self.assertEqual(copy.supports, ['labeled-image', 'fit'])

    def test_label_providers(self):
        """Label providers from ``setUp`` are stored in order with their attributes."""
        label_providers = list(self.database.label_providers())

        # test length
        self.assertEqual(len(label_providers), 2)

        for i, provider in enumerate(label_providers, 1):
            self.assertEqual(provider.id, i)
            self.assertEqual(provider.name, f'Label Provider {i}')
            self.assertEqual(provider.description, f'Description for Label Provider {i}')
            self.assertEqual(provider.root_folder, f'labeldir{i}')

    def test_projects(self):
        """Projects link to the right model/label provider and can be removed and edited."""
        models = list(self.database.models())
        label_providers = list(self.database.label_providers())
        projects = list(self.database.projects())

        # check the projects created in setUp
        for i in range(3):
            project = projects[i]

            self.assertEqual(project.id, i + 1)
            self.assertEqual(project.name, f'Project {i + 1}')
            self.assertEqual(project.description, f'Project Description {i + 1}')
            self.assertEqual(project.model_id, i + 1)
            self.assertEqual(project.model.__dict__, models[i].__dict__)
            self.assertEqual(project.label_provider_id,
                             label_providers[i].id if i < 2 else None)
            self.assertEqual(
                project.label_provider.__dict__ if project.label_provider is not None else None,
                label_providers[i].__dict__ if i < 2 else None
            )
            self.assertEqual(project.root_folder, f'projectdir{i + 1}')
            self.assertEqual(project.external_data, i == 1)
            self.assertEqual(project.data_folder, f'datadir{i + 1}')

        # get projects
        self.assertEqual(len(list(self.database.projects())), 3)

        # remove a project (re-query afterwards to observe the change)
        list(self.database.projects())[0].remove()
        projects = list(self.database.projects())

        self.assertEqual(len(projects), 2)
        self.assertEqual(projects[0].name, 'Project 2')

        # set properties
        project = list(self.database.projects())[0]

        project.set_name('Project 0')
        self.assertEqual(list(self.database.projects())[0].name, 'Project 0')

        project.set_description('Description 0')
        self.assertEqual(list(self.database.projects())[0].description, 'Description 0')


if __name__ == '__main__':
    unittest.main()