test_database.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136
  1. import unittest
  2. from contextlib import closing
  3. from pycs import db
  4. from pycs.database.Database import Database
  5. from pycs.database.Model import Model
  6. from pycs.database.LabelProvider import LabelProvider
  7. class TestDatabase(unittest.TestCase):
  8. def setUp(self) -> None:
  9. db.create_all()
  10. # create database
  11. self.database = Database(discovery=False)
  12. # insert default models and label_providers
  13. with self.database:
  14. for i, supports in enumerate([["labeled-image", "fit"], ["labeled-bounding-boxes"], ["labeled-bounding-boxes"]], 1):
  15. model = Model.new(
  16. name=f"Model {i}",
  17. description=f"Description for Model {i}",
  18. root_folder=f"modeldir{i}",
  19. )
  20. model.supports = supports
  21. if i > 2:
  22. continue
  23. provider = LabelProvider.new(
  24. name=f"Label Provider {i}",
  25. description=f"Description for Label Provider {i}",
  26. root_folder=f"labeldir{i}",
  27. )
  28. # projects
  29. models = list(self.database.models())
  30. label_providers = list(self.database.label_providers())
  31. for i, model in enumerate(models, 1):
  32. self.database.create_project(
  33. name=f'Project {i}',
  34. description=f'Project Description {i}',
  35. model=model,
  36. label_provider=label_providers[i-1] if i < 3 else None,
  37. root_folder=f'projectdir{i}',
  38. external_data=i==2,
  39. data_folder=f'datadir{i}',
  40. )
  41. def tearDown(self) -> None:
  42. db.drop_all()
  43. self.database.close()
  44. def test_models(self):
  45. models = list(self.database.models())
  46. # test length
  47. self.assertEqual(len(models), 3)
  48. # test insert
  49. for i in range(2):
  50. self.assertEqual(models[i].id, i + 1)
  51. self.assertEqual(models[i].name, f'Model {i + 1}')
  52. self.assertEqual(models[i].description, f'Description for Model {i + 1}')
  53. self.assertEqual(models[i].root_folder, f'modeldir{i + 1}')
  54. self.assertEqual(models[0].supports, ['labeled-image', 'fit'])
  55. self.assertEqual(models[1].supports, ['labeled-bounding-boxes'])
  56. # test copy
  57. copy, _ = models[0].copy_to('Copied Model', 'modeldir3')
  58. self.assertEqual(copy.id, 3)
  59. self.assertEqual(copy.name, 'Copied Model')
  60. self.assertEqual(copy.description, 'Description for Model 1')
  61. self.assertEqual(copy.root_folder, 'modeldir3')
  62. self.assertEqual(copy.supports, ['labeled-image', 'fit'])
  63. def test_label_providers(self):
  64. label_providers = list(self.database.label_providers())
  65. # test length
  66. self.assertEqual(len(label_providers), 2)
  67. for i in range(2):
  68. self.assertEqual(label_providers[i].id, i + 1)
  69. self.assertEqual(label_providers[i].name, f'Label Provider {i + 1}')
  70. self.assertEqual(label_providers[i].description, f'Description for Label Provider {i + 1}')
  71. self.assertEqual(label_providers[i].root_folder, f'labeldir{i + 1}')
  72. def test_projects(self):
  73. models = list(self.database.models())
  74. label_providers = list(self.database.label_providers())
  75. projects = list(self.database.projects())
  76. # create projects
  77. for i in range(3):
  78. project = projects[i]
  79. self.assertEqual(project.id, i + 1)
  80. self.assertEqual(project.name, f'Project {i + 1}')
  81. self.assertEqual(project.description, f'Project Description {i + 1}')
  82. self.assertEqual(project.model_id, i + 1)
  83. self.assertEqual(project.model.__dict__, models[i].__dict__)
  84. self.assertEqual(project.label_provider_id, label_providers[i].id if i < 2 else None)
  85. self.assertEqual(
  86. project.label_provider.__dict__ if project.label_provider is not None else None,
  87. label_providers[i].__dict__ if i < 2 else None
  88. )
  89. self.assertEqual(project.root_folder, f'projectdir{i + 1}')
  90. self.assertEqual(project.external_data, i == 1)
  91. self.assertEqual(project.data_folder, f'datadir{i + 1}')
  92. # get projects
  93. self.assertEqual(len(list(self.database.projects())), 3)
  94. # remove a project
  95. list(self.database.projects())[0].remove()
  96. projects = list(self.database.projects())
  97. self.assertEqual(len(projects), 2)
  98. self.assertEqual(projects[0].name, 'Project 2')
  99. # set properties
  100. project = list(self.database.projects())[0]
  101. project.set_name('Project 0')
  102. self.assertEqual(list(self.database.projects())[0].name, 'Project 0')
  103. project.set_description('Description 0')
  104. self.assertEqual(list(self.database.projects())[0].description, 'Description 0')
  105. if __name__ == '__main__':
  106. unittest.main()