test_database.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142
  1. import unittest
  2. from pycs import db
  3. from pycs.database.File import File
  4. from pycs.database.Label import Label
  5. from pycs.database.LabelProvider import LabelProvider
  6. from pycs.database.Model import Model
  7. from pycs.database.Project import Project
  8. from pycs.database.Result import Result
  9. from test.base import BaseTestCase
  10. class TestDatabase(BaseTestCase):
  11. def setUp(self) -> None:
  12. super().setUp(discovery=False)
  13. with db.session.begin_nested():
  14. for i, supports in enumerate([["labeled-image", "fit"], ["labeled-bounding-boxes"], ["labeled-bounding-boxes"]], 1):
  15. model = Model.new(
  16. commit=False,
  17. name=f"Model {i}",
  18. description=f"Description for Model {i}",
  19. root_folder=f"modeldir{i}",
  20. )
  21. model.supports = supports
  22. if i > 2:
  23. continue
  24. provider = LabelProvider.new(
  25. commit=False,
  26. name=f"Label Provider {i}",
  27. description=f"Description for Label Provider {i}",
  28. root_folder=f"labeldir{i}",
  29. configuration_file=f"labeldir{i}/configuration.json"
  30. )
  31. # projects
  32. models = Model.query.all()
  33. label_providers = LabelProvider.query.all()
  34. for i, model in enumerate(models, 1):
  35. Project.new(
  36. name=f'Project {i}',
  37. description=f'Project Description {i}',
  38. model=model,
  39. label_provider=label_providers[i-1] if i < 3 else None,
  40. root_folder=f'projectdir{i}',
  41. external_data=i==2,
  42. data_folder=f'datadir{i}',
  43. )
  44. def test_models(self):
  45. models = Model.query.all()
  46. # test length
  47. self.assertEqual(len(models), 3)
  48. # test insert
  49. for i in range(3):
  50. self.assertEqual(models[i].id, i + 1)
  51. self.assertEqual(models[i].name, f'Model {i + 1}')
  52. self.assertEqual(models[i].description, f'Description for Model {i + 1}')
  53. self.assertEqual(models[i].root_folder, f'modeldir{i + 1}')
  54. self.assertEqual(models[0].supports, ['labeled-image', 'fit'])
  55. self.assertEqual(models[1].supports, ['labeled-bounding-boxes'])
  56. self.assertEqual(models[2].supports, ['labeled-bounding-boxes'])
  57. # test copy
  58. copy, is_new = models[0].copy_to('Copied Model', 'some_other_dir')
  59. self.assertTrue(is_new)
  60. self.assertEqual(copy.id, 4)
  61. self.assertEqual(copy.name, 'Copied Model')
  62. self.assertEqual(copy.description, 'Description for Model 1')
  63. self.assertEqual(copy.root_folder, 'some_other_dir')
  64. self.assertEqual(copy.supports, ['labeled-image', 'fit'])
  65. def test_label_providers(self):
  66. label_providers = LabelProvider.query.all()
  67. # test length
  68. self.assertEqual(len(label_providers), 2)
  69. for i in range(2):
  70. self.assertEqual(label_providers[i].id, i + 1)
  71. self.assertEqual(label_providers[i].name, f'Label Provider {i + 1}')
  72. self.assertEqual(label_providers[i].description, f'Description for Label Provider {i + 1}')
  73. self.assertEqual(label_providers[i].root_folder, f'labeldir{i + 1}')
  74. self.assertEqual(label_providers[i].configuration_file,
  75. f"labeldir{i + 1}/configuration.json")
  76. def test_projects(self):
  77. models = Model.query.all()
  78. label_providers = LabelProvider.query.all()
  79. projects = Project.query.all()
  80. # get projects
  81. self.assertEqual(len(projects), 3)
  82. # create projects
  83. for i, project in enumerate(projects):
  84. self.assertEqual(project.id, i + 1)
  85. self.assertEqual(project.name, f'Project {i + 1}')
  86. self.assertEqual(project.description, f'Project Description {i + 1}')
  87. self.assertEqual(project.model_id, i + 1)
  88. self.assertEqual(project.model.__dict__, models[i].__dict__)
  89. self.assertEqual(project.label_provider_id, label_providers[i].id if i < 2 else None)
  90. self.assertEqual(
  91. project.label_provider.__dict__ if project.label_provider is not None else None,
  92. label_providers[i].__dict__ if i < 2 else None
  93. )
  94. self.assertEqual(project.root_folder, f'projectdir{i + 1}')
  95. self.assertEqual(project.external_data, i == 1)
  96. self.assertEqual(project.data_folder, f'datadir{i + 1}')
  97. # remove a project
  98. Project.query.first().delete()
  99. self.assertEqual(Project.query.count(), 2)
  100. self.assertEqual(Project.query.first().name, 'Project 2')
  101. # set properties
  102. project = Project.query.first()
  103. project.name = 'Project 0'
  104. project.commit()
  105. self.assertEqual(Project.query.first().name, 'Project 0')
  106. project.description = 'Description 0'
  107. project.commit()
  108. self.assertEqual(Project.query.first().description, 'Description 0')
  109. if __name__ == '__main__':
  110. unittest.main()