6
0

label_tests.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330
  1. from flask import url_for
  2. from pycs import db
  3. from pycs.database.Label import Label
  4. from pycs.database.Model import Model
  5. from pycs.database.Project import Project
  6. from tests.base import BaseTestCase
  7. class _BaseLabelTests(BaseTestCase):
  8. def setupModels(self):
  9. model = Model.new(
  10. commit=False,
  11. name="TestModel",
  12. description="Model for a test case",
  13. root_folder="model_folder",
  14. )
  15. model.supports = ["labeled-image"]
  16. model.flush()
  17. self.project = Project.new(
  18. name="test_project",
  19. description="Project for a test case",
  20. model=model,
  21. root_folder="project_folder",
  22. external_data=False,
  23. data_folder="project_folder/data",
  24. )
  25. class LabelCreationTests(_BaseLabelTests):
  26. def setUp(self):
  27. super().setUp()
  28. self.url = url_for("create_label", project_id=self.project.id)
  29. def test_create_labels(self):
  30. self.assertEqual(0, self.project.labels.count())
  31. self.post(self.url, json=dict(), status_code=400)
  32. self.assertEqual(0, self.project.labels.count())
  33. for i in range(1, 11):
  34. self.post(self.url, json=dict(name=f"Label{i}"))
  35. self.assertEqual(self.project.labels.count(), i)
  36. last_label = Label.query.order_by(Label.id.desc()).first()
  37. self.assertEqual(last_label.name, f"Label{i}")
  38. def test_create_label_with_same_name(self):
  39. self.assertEqual(0, self.project.labels.count())
  40. name = "SameName"
  41. for i in range(1, 11):
  42. self.post(self.url, json=dict(name=name))
  43. self.assertEqual(self.project.labels.count(), i)
  44. last_label = Label.query.order_by(Label.id.desc()).first()
  45. self.assertEqual(last_label.name, name)
  46. def test_create_labels_with_reference(self):
  47. self.assertEqual(0, self.project.labels.count())
  48. for i in range(1, 11):
  49. self.post(self.url, json=dict(name=f"Label{i}", reference=f"label_ref{i}"))
  50. self.assertEqual(self.project.labels.count(), i)
  51. last_label = Label.query.order_by(Label.id.desc()).first()
  52. self.assertEqual(last_label.name, f"Label{i}")
  53. self.assertEqual(last_label.reference, f"label_ref{i}")
  54. def test_create_labels_with_same_reference(self):
  55. self.assertEqual(0, self.project.labels.count())
  56. ref = "same_ref"
  57. self.post(self.url, json=dict(name=f"Label", reference=ref))
  58. last_label = Label.query.order_by(Label.id.desc()).first()
  59. self.assertEqual(last_label.name, f"Label")
  60. self.assertEqual(last_label.reference, ref)
  61. for i in range(2, 11):
  62. self.post(self.url, status_code=400, json=dict(name=f"Label{i}", reference=ref))
  63. self.assertEqual(self.project.labels.count(), 1)
  64. last_label = Label.query.order_by(Label.id.desc()).first()
  65. self.assertEqual(last_label.name, f"Label")
  66. self.assertEqual(last_label.reference, ref)
  67. class LabelListTests(_BaseLabelTests):
  68. def setupModels(self):
  69. super().setupModels()
  70. for i in range(1, 11):
  71. parent, is_new = self.project.create_label(name=f"Label{i}", reference=f"label{i}")
  72. self.assertTrue(is_new)
  73. for j in range(1, 4):
  74. label, is_new = self.project.create_label(
  75. name=f"Label{i}_{j}", reference=f"label{i}_{j}",
  76. parent=parent)
  77. self.assertTrue(is_new)
  78. def test_list_labels(self):
  79. self.get(url_for("list_labels", project_id=4242), status_code=404)
  80. response = self.get(url_for("list_labels", project_id=self.project.id))
  81. self.assertTrue(response.is_json)
  82. content = response.json
  83. labels = {l.id: l for l in self.project.labels.all()}
  84. self.assertEqual(len(labels), len(content))
  85. for entry in content:
  86. label = labels[entry["id"]]
  87. self.assertDictEqual(label.serialize(), entry)
  88. def test_list_label_tree(self):
  89. self.get(url_for("list_label_tree", project_id=4242), status_code=404)
  90. response = self.get(url_for("list_label_tree", project_id=self.project.id))
  91. self.assertTrue(response.is_json)
  92. content = response.json
  93. root_labels = {l.id: l for l in Label.query.filter(
  94. Label.project_id==self.project.id, Label.parent_id==None).all()}
  95. self.assertEqual(len(root_labels), len(content))
  96. for entry in content:
  97. label = root_labels[entry["id"]]
  98. self.assertDictEqual(label.serialize(), entry)
  99. class LabelRemovalTests(_BaseLabelTests):
  100. def url(self, label_id):
  101. return url_for("remove_label",
  102. project_id=self.project.id, label_id=label_id)
  103. def setupModels(self):
  104. super().setupModels()
  105. self.N = 10
  106. self.labels = []
  107. with db.session.begin_nested():
  108. for i in range(1, self.N+1):
  109. label, is_new = self.project.create_label(
  110. commit=False,
  111. name=f"Label{i}",
  112. reference=f"Label{i}",
  113. )
  114. assert is_new, "labels should be new"
  115. label.flush()
  116. self.labels.append(label)
  117. def test_remove_label_without_flag(self):
  118. self.assertEqual(self.N, self.project.labels.count())
  119. self.post(self.url(self.labels[0].id), status_code=400)
  120. self.assertEqual(self.N, self.project.labels.count())
  121. self.post(self.url(self.labels[0].id), status_code=400, json=dict(remove=False))
  122. self.assertEqual(self.N, self.project.labels.count())
  123. def test_remove_label(self):
  124. self.assertEqual(self.N, self.project.labels.count())
  125. self.post(self.url(self.labels[0].id), json=dict(remove=True))
  126. self.assertEqual(self.N-1, self.project.labels.count())
  127. def test_remove_non_existing_label(self):
  128. self.assertEqual(self.N, self.project.labels.count())
  129. self.post(self.url(424242), json=dict(remove=True), status_code=404)
  130. self.assertEqual(self.N, self.project.labels.count())
  131. class LabelTreeRemovalTests(_BaseLabelTests):
  132. def url(self, label_id):
  133. return url_for("remove_label",
  134. project_id=self.project.id, label_id=label_id)
  135. def setupModels(self):
  136. super().setupModels()
  137. self.n_roots = 5
  138. self.leafs_per_root = 3
  139. self.n_leafs = self.n_roots * self.leafs_per_root
  140. self.N = self.n_roots + self.n_leafs
  141. with db.session.begin_nested():
  142. for i in range(1, self.n_roots+1):
  143. root_label, is_new = self.project.create_label(
  144. commit=False,
  145. name=f"Label{i}",
  146. reference=f"Label{i}",
  147. )
  148. root_label.flush()
  149. for j in range(1, self.leafs_per_root+1):
  150. label, is_new = self.project.create_label(
  151. commit=False,
  152. name=f"Label{i}_{j}",
  153. reference=f"Label{i}_{j}",
  154. parent=root_label
  155. )
  156. label.flush()
  157. @property
  158. def root_labels(self):
  159. return self.project.labels.filter(Label.parent_id == None)
  160. @property
  161. def leaf_labels(self):
  162. return self.project.labels.filter(Label.parent_id != None)
  163. def test_remove_root_label(self):
  164. self.assertEqual(self.N, self.project.labels.count())
  165. self.assertEqual(self.n_roots, self.root_labels.count())
  166. self.assertEqual(self.n_leafs, self.leaf_labels.count())
  167. label = self.root_labels.first()
  168. self.post(self.url(label.id), json=dict(remove=True))
  169. self.assertEqual(self.N-1, self.project.labels.count())
  170. self.assertEqual(self.n_roots-1+self.leafs_per_root, self.root_labels.count())
  171. self.assertEqual(self.n_leafs-self.leafs_per_root, self.leaf_labels.count())
  172. def test_remove_leaf_label(self):
  173. self.assertEqual(self.N, self.project.labels.count())
  174. self.assertEqual(self.n_roots, self.root_labels.count())
  175. self.assertEqual(self.n_leafs, self.leaf_labels.count())
  176. label = self.leaf_labels.first()
  177. self.post(self.url(label.id), json=dict(remove=True))
  178. self.assertEqual(self.N-1, self.project.labels.count())
  179. self.assertEqual(self.n_roots, self.root_labels.count())
  180. self.assertEqual(self.n_leafs-1, self.leaf_labels.count())
  181. def test_remove_intermediate_label(self):
  182. self.assertEqual(self.N, self.project.labels.count())
  183. self.assertEqual(self.n_roots, self.root_labels.count())
  184. self.assertEqual(self.n_leafs, self.leaf_labels.count())
  185. label = self.leaf_labels.first()
  186. label_parent = label.parent
  187. new_label, is_new = self.project.create_label(
  188. commit=False,
  189. name="New Label",
  190. reference="new_ref",
  191. parent=label,
  192. )
  193. self.assertTrue(is_new)
  194. self.assertEqual(label.id, new_label.parent_id)
  195. self.assertEqual(self.n_roots, self.root_labels.count())
  196. self.assertEqual(self.n_leafs + 1, self.leaf_labels.count())
  197. self.post(self.url(label.id), json=dict(remove=True))
  198. self.assertEqual(label_parent.id, new_label.parent_id)
  199. self.assertEqual(self.n_roots, self.root_labels.count())
  200. self.assertEqual(self.n_leafs, self.leaf_labels.count())
  201. class LabelEditTests(_BaseLabelTests):
  202. def test_edit_name(self):
  203. self.assertEqual(0, self.project.labels.count())
  204. label = self.project.create_label(name="Label", reference="label")
  205. self.assertEqual(1, self.project.labels.count())
  206. label = self.project.labels.first()
  207. self.assertEqual("Label", label.name)
  208. self.post(self.url(label.id), json=dict(), status_code=400)
  209. name = "Another name"
  210. self.post(self.url(label.id), json=dict(name=name))
  211. label = self.project.labels.first()
  212. self.assertEqual(name, label.name)
  213. def test_edit_name(self):
  214. self.assertEqual(0, self.project.labels.count())
  215. label = self.project.create_label(name="Label", reference="label")
  216. self.assertEqual(1, self.project.labels.count())
  217. label = self.project.labels.first()
  218. self.assertEqual("Label", label.name)
  219. url = lambda label_id: url_for("edit_label_name",
  220. project_id=self.project.id, label_id=label_id)
  221. self.post(url(label.id), json=dict(), status_code=400)
  222. name = "Another name"
  223. self.post(url(4242), json=dict(name=name), status_code=404)
  224. self.post(url(label.id), json=dict(name=name))
  225. label = self.project.labels.first()
  226. self.assertEqual(name, label.name)
  227. def test_edit_parent(self):
  228. self.assertEqual(0, self.project.labels.count())
  229. label, _ = self.project.create_label(name="Label", reference="label")
  230. parent, _ = self.project.create_label(name="Label2", reference="label2")
  231. self.assertEqual(2, self.project.labels.count())
  232. label = Label.query.get(label.id)
  233. self.assertIsNone(label.parent_id)
  234. url = lambda label_id: url_for("edit_label_parent",
  235. project_id=self.project.id, label_id=label_id)
  236. self.post(url(label.id), json=dict(), status_code=400)
  237. self.post(url(4242), json=dict(parent=parent.id), status_code=404)
  238. self.post(url(label.id), json=dict(parent=parent.id))
  239. label = Label.query.get(label.id)
  240. self.assertEqual(parent.id, label.parent_id)