test_database.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141
  1. import unittest
  2. from pycs import db
  3. from pycs.database.File import File
  4. from pycs.database.Label import Label
  5. from pycs.database.LabelProvider import LabelProvider
  6. from pycs.database.Model import Model
  7. from pycs.database.Project import Project
  8. from pycs.database.Result import Result
  9. from tests.base import BaseTestCase
  10. class TestDatabase(BaseTestCase):
  11. def setupModels(self):
  12. with db.session.begin_nested():
  13. for i, supports in enumerate([["labeled-image", "fit"], ["labeled-bounding-boxes"], ["labeled-bounding-boxes"]], 1):
  14. model = Model.new(
  15. commit=False,
  16. name=f"Model {i}",
  17. description=f"Description for Model {i}",
  18. root_folder=f"modeldir{i}",
  19. )
  20. model.supports = supports
  21. if i > 2:
  22. continue
  23. provider = LabelProvider.new(
  24. commit=False,
  25. name=f"Label Provider {i}",
  26. description=f"Description for Label Provider {i}",
  27. root_folder=f"labeldir{i}",
  28. configuration_file=f"labeldir{i}/configuration.json"
  29. )
  30. # projects
  31. models = Model.query.all()
  32. label_providers = LabelProvider.query.all()
  33. for i, model in enumerate(models, 1):
  34. Project.new(
  35. name=f'Project {i}',
  36. description=f'Project Description {i}',
  37. model=model,
  38. label_provider=label_providers[i-1] if i < 3 else None,
  39. root_folder=f'projectdir{i}',
  40. external_data=i==2,
  41. data_folder=f'datadir{i}',
  42. )
  43. def test_models(self):
  44. models = Model.query.all()
  45. # test length
  46. self.assertEqual(len(models), 3)
  47. # test insert
  48. for i in range(3):
  49. self.assertEqual(models[i].id, i + 1)
  50. self.assertEqual(models[i].name, f'Model {i + 1}')
  51. self.assertEqual(models[i].description, f'Description for Model {i + 1}')
  52. self.assertEqual(models[i].root_folder, f'modeldir{i + 1}')
  53. self.assertEqual(models[0].supports, ['labeled-image', 'fit'])
  54. self.assertEqual(models[1].supports, ['labeled-bounding-boxes'])
  55. self.assertEqual(models[2].supports, ['labeled-bounding-boxes'])
  56. # test copy
  57. copy, is_new = models[0].copy_to('Copied Model', 'some_other_dir')
  58. self.assertTrue(is_new)
  59. self.assertEqual(copy.id, 4)
  60. self.assertEqual(copy.name, 'Copied Model')
  61. self.assertEqual(copy.description, 'Description for Model 1')
  62. self.assertEqual(copy.root_folder, 'some_other_dir')
  63. self.assertEqual(copy.supports, ['labeled-image', 'fit'])
  64. def test_label_providers(self):
  65. label_providers = LabelProvider.query.all()
  66. # test length
  67. self.assertEqual(len(label_providers), 2)
  68. for i in range(2):
  69. self.assertEqual(label_providers[i].id, i + 1)
  70. self.assertEqual(label_providers[i].name, f'Label Provider {i + 1}')
  71. self.assertEqual(label_providers[i].description, f'Description for Label Provider {i + 1}')
  72. self.assertEqual(label_providers[i].root_folder, f'labeldir{i + 1}')
  73. self.assertEqual(label_providers[i].configuration_file,
  74. f"labeldir{i + 1}/configuration.json")
  75. def test_projects(self):
  76. models = Model.query.all()
  77. label_providers = LabelProvider.query.all()
  78. projects = Project.query.all()
  79. # get projects
  80. self.assertEqual(len(projects), 3)
  81. # create projects
  82. for i, project in enumerate(projects):
  83. self.assertEqual(project.id, i + 1)
  84. self.assertEqual(project.name, f'Project {i + 1}')
  85. self.assertEqual(project.description, f'Project Description {i + 1}')
  86. self.assertEqual(project.model_id, i + 1)
  87. self.assertEqual(project.model.__dict__, models[i].__dict__)
  88. self.assertEqual(project.label_provider_id, label_providers[i].id if i < 2 else None)
  89. self.assertEqual(
  90. project.label_provider.__dict__ if project.label_provider is not None else None,
  91. label_providers[i].__dict__ if i < 2 else None
  92. )
  93. self.assertEqual(project.root_folder, f'projectdir{i + 1}')
  94. self.assertEqual(project.external_data, i == 1)
  95. self.assertEqual(project.data_folder, f'datadir{i + 1}')
  96. # remove a project
  97. Project.query.first().delete()
  98. self.assertEqual(Project.query.count(), 2)
  99. self.assertEqual(Project.query.first().name, 'Project 2')
  100. # set properties
  101. project = Project.query.first()
  102. project.name = 'Project 0'
  103. project.commit()
  104. self.assertEqual(Project.query.first().name, 'Project 0')
  105. project.description = 'Description 0'
  106. project.commit()
  107. self.assertEqual(Project.query.first().description, 'Description 0')
  108. if __name__ == '__main__':
  109. unittest.main()