# test_database.py
import unittest
from contextlib import closing

from pycs.database.Database import Database
  4. class TestDatabase(unittest.TestCase):
  5. def setUp(self) -> None:
  6. # create database
  7. self.database = Database(discovery=False)
  8. # insert default models and label_providers
  9. with self.database:
  10. with closing(self.database.con.cursor()) as cursor:
  11. # models
  12. cursor.execute('''
  13. INSERT INTO models (name, description, root_folder, supports)
  14. VALUES
  15. ('Model 1', 'Description for Model 1', 'modeldir1', '["labeled-image", "fit"]'),
  16. ('Model 2', 'Description for Model 2', 'modeldir2', '["labeled-bounding-boxes"]'),
  17. ('Model 3', 'Description for Model 3', 'modeldir3', '["labeled-bounding-boxes"]')
  18. ''')
  19. # label providers
  20. cursor.execute('''
  21. INSERT INTO label_providers (name, description, root_folder, configuration_file)
  22. VALUES
  23. ('Label Provider 1', 'Description for Label Provider 1', 'labeldir1', 'file1'),
  24. ('Label Provider 2', 'Description for Label Provider 2', 'labeldir2', 'file2')
  25. ''')
  26. # projects
  27. models = list(self.database.models())
  28. label_providers = list(self.database.label_providers())
  29. for i in range(3):
  30. self.database.create_project(
  31. f'Project {i + 1}', f'Project Description {i + 1}',
  32. models[i],
  33. label_providers[i] if i < 2 else None,
  34. f'projectdir{i + 1}', i == 1, f'datadir{i + 1}'
  35. )
  36. def tearDown(self) -> None:
  37. self.database.close()
  38. def test_models(self):
  39. models = list(self.database.models())
  40. # test length
  41. self.assertEqual(len(models), 3)
  42. # test insert
  43. for i in range(2):
  44. self.assertEqual(models[i].identifier, i + 1)
  45. self.assertEqual(models[i].name, f'Model {i + 1}')
  46. self.assertEqual(models[i].description, f'Description for Model {i + 1}')
  47. self.assertEqual(models[i].root_folder, f'modeldir{i + 1}')
  48. self.assertEqual(models[0].supports, ['labeled-image', 'fit'])
  49. self.assertEqual(models[1].supports, ['labeled-bounding-boxes'])
  50. # test copy
  51. copy, _ = models[0].copy_to('Copied Model', 'modeldir3')
  52. self.assertEqual(copy.identifier, 3)
  53. self.assertEqual(copy.name, 'Copied Model')
  54. self.assertEqual(copy.description, 'Description for Model 1')
  55. self.assertEqual(copy.root_folder, 'modeldir3')
  56. self.assertEqual(copy.supports, ['labeled-image', 'fit'])
  57. def test_label_providers(self):
  58. label_providers = list(self.database.label_providers())
  59. # test length
  60. self.assertEqual(len(label_providers), 2)
  61. for i in range(2):
  62. self.assertEqual(label_providers[i].identifier, i + 1)
  63. self.assertEqual(label_providers[i].name, f'Label Provider {i + 1}')
  64. self.assertEqual(label_providers[i].description, f'Description for Label Provider {i + 1}')
  65. self.assertEqual(label_providers[i].root_folder, f'labeldir{i + 1}')
  66. self.assertEqual(label_providers[i].configuration_file, f'file{i + 1}')
  67. def test_projects(self):
  68. models = list(self.database.models())
  69. label_providers = list(self.database.label_providers())
  70. projects = list(self.database.projects())
  71. # create projects
  72. for i in range(3):
  73. project = projects[i]
  74. self.assertEqual(project.identifier, i + 1)
  75. self.assertEqual(project.name, f'Project {i + 1}')
  76. self.assertEqual(project.description, f'Project Description {i + 1}')
  77. self.assertEqual(project.model_id, i + 1)
  78. self.assertEqual(project.model().__dict__, models[i].__dict__)
  79. self.assertEqual(project.label_provider_id, label_providers[i].identifier if i < 2 else None)
  80. self.assertEqual(
  81. project.label_provider().__dict__ if project.label_provider() is not None else None,
  82. label_providers[i].__dict__ if i < 2 else None
  83. )
  84. self.assertEqual(project.root_folder, f'projectdir{i + 1}')
  85. self.assertEqual(project.external_data, i == 1)
  86. self.assertEqual(project.data_folder, f'datadir{i + 1}')
  87. # get projects
  88. self.assertEqual(len(list(self.database.projects())), 3)
  89. # remove a project
  90. list(self.database.projects())[0].remove()
  91. projects = list(self.database.projects())
  92. self.assertEqual(len(projects), 2)
  93. self.assertEqual(projects[0].name, 'Project 2')
  94. # set properties
  95. project = list(self.database.projects())[0]
  96. project.set_name('Project 0')
  97. self.assertEqual(list(self.database.projects())[0].name, 'Project 0')
  98. project.set_description('Description 0')
  99. self.assertEqual(list(self.database.projects())[0].description, 'Description 0')
  100. if __name__ == '__main__':
  101. unittest.main()