6
0

label_tests.py 7.7 KB


  1. from flask import url_for
  2. from pycs import db
  3. from pycs.database.Label import Label
  4. from pycs.database.Model import Model
  5. from pycs.database.Project import Project
  6. from tests.base import BaseTestCase
  class _BaseLabelTests(BaseTestCase):
      """Shared fixture for the label tests: one project backed by one model.

      Subclasses extend ``setupModels`` (calling ``super().setupModels()``)
      to add labels to ``self.project``.
      """

      def setupModels(self):
          """Create the test model and the project stored as ``self.project``."""
          model_settings = dict(
              name="TestModel",
              description="Model for a test case",
              root_folder="model_folder",
          )
          # Created with commit=False and only flushed afterwards, presumably
          # so it is persisted before Project.new references it — confirm.
          test_model = Model.new(commit=False, **model_settings)
          test_model.supports = ["labeled-image"]
          test_model.flush()

          project_settings = dict(
              name="test_project",
              description="Project for a test case",
              root_folder="project_folder",
              external_data=False,
              data_folder="project_folder/data",
          )
          self.project = Project.new(model=test_model, **project_settings)
  class LabelCreationTests(_BaseLabelTests):
      """Tests for the ``create_label`` endpoint."""

      # Number of labels each test attempts to create.
      n_requests = 10

      def setUp(self):
          super().setUp()
          self.url = url_for("create_label", project_id=self.project.id)

      @staticmethod
      def _last_label():
          """Return the label with the highest id, i.e. the latest one created."""
          return Label.query.order_by(Label.id.desc()).first()

      def test_create_labels(self):
          """Every POST with a new name adds exactly one label."""
          self.assertEqual(0, self.project.labels.count())
          for i in range(1, self.n_requests + 1):
              self.post(self.url, json=dict(name=f"Label{i}"))
              self.assertEqual(self.project.labels.count(), i)
              self.assertEqual(self._last_label().name, f"Label{i}")

      def test_create_label_with_same_name(self):
          """Repeated POSTs with the same name are all accepted."""
          self.assertEqual(0, self.project.labels.count())
          name = "SameName"
          for i in range(1, self.n_requests + 1):
              self.post(self.url, json=dict(name=name))
              self.assertEqual(self.project.labels.count(), i)
              self.assertEqual(self._last_label().name, name)

      def test_create_labels_with_reference(self):
          """A reference sent on creation is stored on the new label."""
          self.assertEqual(0, self.project.labels.count())
          for i in range(1, self.n_requests + 1):
              self.post(self.url, json=dict(name=f"Label{i}", reference=f"label_ref{i}"))
              self.assertEqual(self.project.labels.count(), i)
              last_label = self._last_label()
              self.assertEqual(last_label.name, f"Label{i}")
              self.assertEqual(last_label.reference, f"label_ref{i}")

      def test_create_labels_with_same_reference(self):
          """A second label with an already-used reference is rejected with 400."""
          self.assertEqual(0, self.project.labels.count())
          ref = "same_ref"
          self.post(self.url, json=dict(name="Label", reference=ref))
          self.assertEqual(1, self.project.labels.count())
          last_label = self._last_label()
          self.assertEqual(last_label.name, "Label")
          self.assertEqual(last_label.reference, ref)
          for i in range(2, self.n_requests + 1):
              self.post(self.url, status_code=400, json=dict(name=f"Label{i}", reference=ref))
              # The rejected request must neither add a label nor alter the first one.
              self.assertEqual(self.project.labels.count(), 1)
              last_label = self._last_label()
              self.assertEqual(last_label.name, "Label")
              self.assertEqual(last_label.reference, ref)
  class LabelRemovalTests(_BaseLabelTests):
      """Tests for the ``remove_label`` endpoint on labels created without a parent."""

      def url(self, label_id):
          """Return the removal URL for ``label_id`` in the test project."""
          return url_for("remove_label",
                         project_id=self.project.id, label_id=label_id)

      def setupModels(self):
          """Add ``self.N`` labels (Label1..LabelN) to the project; keep them in ``self.labels``."""
          super().setupModels()
          self.N = 10
          self.labels = []
          with db.session.begin_nested():
              for i in range(1, self.N + 1):
                  label, is_new = self.project.create_label(
                      commit=False,
                      name=f"Label{i}",
                      reference=f"Label{i}",
                  )
                  # unittest assertion rather than ``assert``: survives ``python -O``
                  self.assertTrue(is_new, "labels should be new")
                  label.flush()
                  self.labels.append(label)

      def _has_label(self, label_id):
          """Return whether the project still contains a label with ``label_id``."""
          return self.project.labels.filter(Label.id == label_id).count() > 0

      def test_remove_label_without_flag(self):
          """Without ``remove=True`` the request fails with 400 and nothing is removed."""
          # Read the id once; the ORM object is not touched again after failed requests.
          label_id = self.labels[0].id
          self.assertEqual(self.N, self.project.labels.count())
          self.post(self.url(label_id), status_code=400)
          self.assertEqual(self.N, self.project.labels.count())
          self.post(self.url(label_id), status_code=400, json=dict(remove=False))
          self.assertEqual(self.N, self.project.labels.count())
          self.assertTrue(self._has_label(label_id))

      def test_remove_label(self):
          """Removing a label deletes exactly that label."""
          label_id = self.labels[0].id
          self.assertEqual(self.N, self.project.labels.count())
          self.post(self.url(label_id), json=dict(remove=True))
          self.assertEqual(self.N - 1, self.project.labels.count())
          self.assertFalse(self._has_label(label_id))

      def test_remove_non_existing_label(self):
          """Removing an unknown label id yields 404 and leaves all labels in place."""
          self.assertEqual(self.N, self.project.labels.count())
          self.post(self.url(424242), json=dict(remove=True), status_code=404)
          self.assertEqual(self.N, self.project.labels.count())
  class LabelTreeRemovalTests(_BaseLabelTests):
      """Tests for ``remove_label`` on a two-level label hierarchy."""

      def url(self, label_id):
          """Return the removal URL for ``label_id`` in the test project."""
          return url_for("remove_label",
                         project_id=self.project.id, label_id=label_id)

      def setupModels(self):
          """Create ``n_roots`` parentless labels, each with ``leafs_per_root`` children."""
          super().setupModels()
          self.n_roots = 5
          self.leafs_per_root = 3
          self.n_leafs = self.n_roots * self.leafs_per_root
          self.N = self.n_roots + self.n_leafs
          with db.session.begin_nested():
              for i in range(1, self.n_roots + 1):
                  root_label, is_new = self.project.create_label(
                      commit=False,
                      name=f"Label{i}",
                      reference=f"Label{i}",
                  )
                  self.assertTrue(is_new, "labels should be new")
                  root_label.flush()
                  for j in range(1, self.leafs_per_root + 1):
                      label, is_new = self.project.create_label(
                          commit=False,
                          name=f"Label{i}_{j}",
                          reference=f"Label{i}_{j}",
                          parent=root_label,
                      )
                      self.assertTrue(is_new, "labels should be new")
                      label.flush()

      @property
      def root_labels(self):
          """Query for the project's labels that have no parent."""
          return self.project.labels.filter(Label.parent_id.is_(None))

      @property
      def leaf_labels(self):
          """Query for the project's labels that have a parent."""
          return self.project.labels.filter(Label.parent_id.isnot(None))

      def _assert_initial_tree(self):
          """Check that the fixture hierarchy from ``setupModels`` is intact."""
          self.assertEqual(self.N, self.project.labels.count())
          self.assertEqual(self.n_roots, self.root_labels.count())
          self.assertEqual(self.n_leafs, self.leaf_labels.count())

      def test_remove_root_label(self):
          """Removing a root deletes only that label; its children become parentless."""
          self._assert_initial_tree()
          label = self.root_labels.first()
          self.post(self.url(label.id), json=dict(remove=True))
          self.assertEqual(self.N - 1, self.project.labels.count())
          self.assertEqual(self.n_roots - 1 + self.leafs_per_root, self.root_labels.count())
          self.assertEqual(self.n_leafs - self.leafs_per_root, self.leaf_labels.count())

      def test_remove_leaf_label(self):
          """Removing a child label deletes only that label."""
          self._assert_initial_tree()
          label = self.leaf_labels.first()
          self.post(self.url(label.id), json=dict(remove=True))
          self.assertEqual(self.N - 1, self.project.labels.count())
          self.assertEqual(self.n_roots, self.root_labels.count())
          self.assertEqual(self.n_leafs - 1, self.leaf_labels.count())

      def test_remove_intermediate_label(self):
          """Removing a mid-level label re-parents its child onto the removed label's parent."""
          self._assert_initial_tree()
          label = self.leaf_labels.first()
          label_parent = label.parent
          new_label, is_new = self.project.create_label(
              commit=False,
              name="New Label",
              reference="new_ref",
              parent=label,
          )
          self.assertTrue(is_new)
          self.assertEqual(label.id, new_label.parent_id)
          self.assertEqual(self.n_roots, self.root_labels.count())
          self.assertEqual(self.n_leafs + 1, self.leaf_labels.count())

          self.post(self.url(label.id), json=dict(remove=True))
          self.assertEqual(label_parent.id, new_label.parent_id)
          # One label was added and one removed, so the total is back to N.
          self.assertEqual(self.N, self.project.labels.count())
          self.assertEqual(self.n_roots, self.root_labels.count())
          self.assertEqual(self.n_leafs, self.leaf_labels.count())