6
0

test_database.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178
  1. import unittest
  2. from contextlib import closing
  3. from pycs import db
  4. from pycs.database.Database import Database
  5. from pycs.database.File import File
  6. from pycs.database.Label import Label
  7. from pycs.database.Model import Model
  8. from pycs.database.LabelProvider import LabelProvider
  9. class TestDatabase(unittest.TestCase):
  10. def setUp(self) -> None:
  11. db.create_all()
  12. # create database
  13. self.database = Database(discovery=False)
  14. # insert default models and label_providers
  15. with self.database:
  16. for i, supports in enumerate([["labeled-image", "fit"], ["labeled-bounding-boxes"], ["labeled-bounding-boxes"]], 1):
  17. model = Model.new(
  18. name=f"Model {i}",
  19. description=f"Description for Model {i}",
  20. root_folder=f"modeldir{i}",
  21. )
  22. model.supports = supports
  23. if i > 2:
  24. continue
  25. provider = LabelProvider.new(
  26. name=f"Label Provider {i}",
  27. description=f"Description for Label Provider {i}",
  28. root_folder=f"labeldir{i}",
  29. )
  30. # projects
  31. models = list(self.database.models())
  32. label_providers = list(self.database.label_providers())
  33. for i, model in enumerate(models, 1):
  34. self.database.create_project(
  35. name=f'Project {i}',
  36. description=f'Project Description {i}',
  37. model=model,
  38. label_provider=label_providers[i-1] if i < 3 else None,
  39. root_folder=f'projectdir{i}',
  40. external_data=i==2,
  41. data_folder=f'datadir{i}',
  42. )
  43. def tearDown(self) -> None:
  44. db.drop_all()
  45. self.database.close()
  46. def test_models(self):
  47. models = list(self.database.models())
  48. # test length
  49. self.assertEqual(len(models), 3)
  50. # test insert
  51. for i in range(2):
  52. self.assertEqual(models[i].id, i + 1)
  53. self.assertEqual(models[i].name, f'Model {i + 1}')
  54. self.assertEqual(models[i].description, f'Description for Model {i + 1}')
  55. self.assertEqual(models[i].root_folder, f'modeldir{i + 1}')
  56. self.assertEqual(models[0].supports, ['labeled-image', 'fit'])
  57. self.assertEqual(models[1].supports, ['labeled-bounding-boxes'])
  58. # test copy
  59. copy, _ = models[0].copy_to('Copied Model', 'modeldir3')
  60. self.assertEqual(copy.id, 3)
  61. self.assertEqual(copy.name, 'Copied Model')
  62. self.assertEqual(copy.description, 'Description for Model 1')
  63. self.assertEqual(copy.root_folder, 'modeldir3')
  64. self.assertEqual(copy.supports, ['labeled-image', 'fit'])
  65. def test_label_providers(self):
  66. label_providers = list(self.database.label_providers())
  67. # test length
  68. self.assertEqual(len(label_providers), 2)
  69. for i in range(2):
  70. self.assertEqual(label_providers[i].id, i + 1)
  71. self.assertEqual(label_providers[i].name, f'Label Provider {i + 1}')
  72. self.assertEqual(label_providers[i].description, f'Description for Label Provider {i + 1}')
  73. self.assertEqual(label_providers[i].root_folder, f'labeldir{i + 1}')
  74. def test_projects(self):
  75. models = list(self.database.models())
  76. label_providers = list(self.database.label_providers())
  77. projects = list(self.database.projects())
  78. # create projects
  79. for i in range(3):
  80. project = projects[i]
  81. self.assertEqual(project.id, i + 1)
  82. self.assertEqual(project.name, f'Project {i + 1}')
  83. self.assertEqual(project.description, f'Project Description {i + 1}')
  84. self.assertEqual(project.model_id, i + 1)
  85. self.assertEqual(project.model.__dict__, models[i].__dict__)
  86. self.assertEqual(project.label_provider_id, label_providers[i].id if i < 2 else None)
  87. self.assertEqual(
  88. project.label_provider.__dict__ if project.label_provider is not None else None,
  89. label_providers[i].__dict__ if i < 2 else None
  90. )
  91. self.assertEqual(project.root_folder, f'projectdir{i + 1}')
  92. self.assertEqual(project.external_data, i == 1)
  93. self.assertEqual(project.data_folder, f'datadir{i + 1}')
  94. # get projects
  95. self.assertEqual(len(list(self.database.projects())), 3)
  96. # remove a project
  97. list(self.database.projects())[0].remove()
  98. projects = list(self.database.projects())
  99. self.assertEqual(len(projects), 2)
  100. self.assertEqual(projects[0].name, 'Project 2')
  101. # set properties
  102. project = list(self.database.projects())[0]
  103. project.set_name('Project 0')
  104. self.assertEqual(list(self.database.projects())[0].name, 'Project 0')
  105. project.set_description('Description 0')
  106. self.assertEqual(list(self.database.projects())[0].description, 'Description 0')
  107. def test_no_files_after_project_deletion(self):
  108. project = self.database.project(1)
  109. for i in range(5):
  110. file, is_new = project.add_file(
  111. uuid=f"some_string{i}",
  112. name=f"some_name{i}",
  113. filename=f"some_filename{i}",
  114. file_type="image",
  115. extension=".jpg",
  116. size=42,
  117. )
  118. self.assertTrue(is_new)
  119. self.assertIsNotNone(file)
  120. self.assertEqual(5, File.query.filter_by(project_id=project.id).count())
  121. project.remove()
  122. self.assertIsNone(self.database.project(1))
  123. self.assertEqual(0, File.query.filter_by(project_id=project.id).count())
  124. def test_no_labels_after_project_deletion(self):
  125. project = self.database.project(1)
  126. for i in range(5):
  127. label, is_new = project.create_label(
  128. name=f"label{i}",
  129. reference=f"ref{i}"
  130. )
  131. self.assertTrue(is_new)
  132. self.assertIsNotNone(label)
  133. self.assertEqual(5, Label.query.filter_by(project_id=project.id).count())
  134. project.remove()
  135. self.assertIsNone(self.database.project(1))
  136. self.assertEqual(0, Label.query.filter_by(project_id=project.id).count())
  137. if __name__ == '__main__':
  138. unittest.main()