label_tests.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326
  1. from flask import url_for
  2. from pycs import db
  3. from pycs.database.Label import Label
  4. from pycs.database.Model import Model
  5. from pycs.database.Project import Project
  6. from tests.client.project_tests import _BaseProjectTests
  7. class _BaseLabelTests(_BaseProjectTests):
  8. def setupModels(self):
  9. super().setupModels()
  10. self.project = Project.new(
  11. name="test_project",
  12. description="Project for a test case",
  13. model=self.model,
  14. root_folder="project_folder",
  15. external_data=False,
  16. data_folder="project_folder/data",
  17. )
  18. def tearDown(self):
  19. self.project.delete()
  20. super().tearDown()
  21. class LabelCreationTests(_BaseLabelTests):
  22. def setUp(self):
  23. super().setUp()
  24. self.url = url_for("create_label", project_id=self.project.id)
  25. def test_create_labels(self):
  26. self.assertEqual(0, self.project.labels.count())
  27. self.post(self.url, json=dict(), status_code=400)
  28. self.assertEqual(0, self.project.labels.count())
  29. for i in range(1, 11):
  30. self.post(self.url, json=dict(name=f"Label{i}"))
  31. self.assertEqual(self.project.labels.count(), i)
  32. last_label = Label.query.order_by(Label.id.desc()).first()
  33. self.assertEqual(last_label.name, f"Label{i}")
  34. def test_create_label_with_same_name(self):
  35. self.assertEqual(0, self.project.labels.count())
  36. name = "SameName"
  37. for i in range(1, 11):
  38. self.post(self.url, json=dict(name=name))
  39. self.assertEqual(self.project.labels.count(), i)
  40. last_label = Label.query.order_by(Label.id.desc()).first()
  41. self.assertEqual(last_label.name, name)
  42. def test_create_labels_with_reference(self):
  43. self.assertEqual(0, self.project.labels.count())
  44. for i in range(1, 11):
  45. self.post(self.url, json=dict(name=f"Label{i}", reference=f"label_ref{i}"))
  46. self.assertEqual(self.project.labels.count(), i)
  47. last_label = Label.query.order_by(Label.id.desc()).first()
  48. self.assertEqual(last_label.name, f"Label{i}")
  49. self.assertEqual(last_label.reference, f"label_ref{i}")
  50. def test_create_labels_with_same_reference(self):
  51. self.assertEqual(0, self.project.labels.count())
  52. ref = "same_ref"
  53. self.post(self.url, json=dict(name=f"Label", reference=ref))
  54. last_label = Label.query.order_by(Label.id.desc()).first()
  55. self.assertEqual(last_label.name, f"Label")
  56. self.assertEqual(last_label.reference, ref)
  57. for i in range(2, 11):
  58. self.post(self.url, status_code=400, json=dict(name=f"Label{i}", reference=ref))
  59. self.assertEqual(self.project.labels.count(), 1)
  60. last_label = Label.query.order_by(Label.id.desc()).first()
  61. self.assertEqual(last_label.name, f"Label")
  62. self.assertEqual(last_label.reference, ref)
  63. class LabelListTests(_BaseLabelTests):
  64. def setupModels(self):
  65. super().setupModels()
  66. for i in range(1, 11):
  67. parent, is_new = self.project.create_label(name=f"Label{i}", reference=f"label{i}")
  68. self.assertTrue(is_new)
  69. for j in range(1, 4):
  70. label, is_new = self.project.create_label(
  71. name=f"Label{i}_{j}", reference=f"label{i}_{j}",
  72. parent=parent)
  73. self.assertTrue(is_new)
  74. def test_list_labels(self):
  75. self.get(url_for("list_labels", project_id=4242), status_code=404)
  76. response = self.get(url_for("list_labels", project_id=self.project.id))
  77. self.assertTrue(response.is_json)
  78. content = response.json
  79. labels = {l.id: l for l in self.project.labels.all()}
  80. self.assertEqual(len(labels), len(content))
  81. for entry in content:
  82. label = labels[entry["id"]]
  83. self.assertDictEqual(label.serialize(), entry)
  84. def test_list_label_tree(self):
  85. self.get(url_for("list_label_tree", project_id=4242), status_code=404)
  86. response = self.get(url_for("list_label_tree", project_id=self.project.id))
  87. self.assertTrue(response.is_json)
  88. content = response.json
  89. root_labels = {l.id: l for l in Label.query.filter(
  90. Label.project_id==self.project.id, Label.parent_id==None).all()}
  91. self.assertEqual(len(root_labels), len(content))
  92. for entry in content:
  93. label = root_labels[entry["id"]]
  94. self.assertDictEqual(label.serialize(), entry)
  95. class LabelRemovalTests(_BaseLabelTests):
  96. def url(self, label_id):
  97. return url_for("remove_label",
  98. project_id=self.project.id, label_id=label_id)
  99. def setupModels(self):
  100. super().setupModels()
  101. self.N = 10
  102. self.labels = []
  103. with db.session.begin_nested():
  104. for i in range(1, self.N+1):
  105. label, is_new = self.project.create_label(
  106. commit=False,
  107. name=f"Label{i}",
  108. reference=f"Label{i}",
  109. )
  110. assert is_new, "labels should be new"
  111. label.flush()
  112. self.labels.append(label)
  113. def test_remove_label_without_flag(self):
  114. self.assertEqual(self.N, self.project.labels.count())
  115. self.post(self.url(self.labels[0].id), status_code=400)
  116. self.assertEqual(self.N, self.project.labels.count())
  117. self.post(self.url(self.labels[0].id), status_code=400, json=dict(remove=False))
  118. self.assertEqual(self.N, self.project.labels.count())
  119. def test_remove_label(self):
  120. self.assertEqual(self.N, self.project.labels.count())
  121. self.post(self.url(self.labels[0].id), json=dict(remove=True))
  122. self.assertEqual(self.N-1, self.project.labels.count())
  123. def test_remove_non_existing_label(self):
  124. self.assertEqual(self.N, self.project.labels.count())
  125. self.post(self.url(424242), json=dict(remove=True), status_code=404)
  126. self.assertEqual(self.N, self.project.labels.count())
  127. class LabelTreeRemovalTests(_BaseLabelTests):
  128. def url(self, label_id):
  129. return url_for("remove_label",
  130. project_id=self.project.id, label_id=label_id)
  131. def setupModels(self):
  132. super().setupModels()
  133. self.n_roots = 5
  134. self.leafs_per_root = 3
  135. self.n_leafs = self.n_roots * self.leafs_per_root
  136. self.N = self.n_roots + self.n_leafs
  137. with db.session.begin_nested():
  138. for i in range(1, self.n_roots+1):
  139. root_label, is_new = self.project.create_label(
  140. commit=False,
  141. name=f"Label{i}",
  142. reference=f"Label{i}",
  143. )
  144. root_label.flush()
  145. for j in range(1, self.leafs_per_root+1):
  146. label, is_new = self.project.create_label(
  147. commit=False,
  148. name=f"Label{i}_{j}",
  149. reference=f"Label{i}_{j}",
  150. parent=root_label
  151. )
  152. label.flush()
  153. @property
  154. def root_labels(self):
  155. return self.project.labels.filter(Label.parent_id == None)
  156. @property
  157. def leaf_labels(self):
  158. return self.project.labels.filter(Label.parent_id != None)
  159. def test_remove_root_label(self):
  160. self.assertEqual(self.N, self.project.labels.count())
  161. self.assertEqual(self.n_roots, self.root_labels.count())
  162. self.assertEqual(self.n_leafs, self.leaf_labels.count())
  163. label = self.root_labels.first()
  164. self.post(self.url(label.id), json=dict(remove=True))
  165. self.assertEqual(self.N-1, self.project.labels.count())
  166. self.assertEqual(self.n_roots-1+self.leafs_per_root, self.root_labels.count())
  167. self.assertEqual(self.n_leafs-self.leafs_per_root, self.leaf_labels.count())
  168. def test_remove_leaf_label(self):
  169. self.assertEqual(self.N, self.project.labels.count())
  170. self.assertEqual(self.n_roots, self.root_labels.count())
  171. self.assertEqual(self.n_leafs, self.leaf_labels.count())
  172. label = self.leaf_labels.first()
  173. self.post(self.url(label.id), json=dict(remove=True))
  174. self.assertEqual(self.N-1, self.project.labels.count())
  175. self.assertEqual(self.n_roots, self.root_labels.count())
  176. self.assertEqual(self.n_leafs-1, self.leaf_labels.count())
  177. def test_remove_intermediate_label(self):
  178. self.assertEqual(self.N, self.project.labels.count())
  179. self.assertEqual(self.n_roots, self.root_labels.count())
  180. self.assertEqual(self.n_leafs, self.leaf_labels.count())
  181. label = self.leaf_labels.first()
  182. label_parent = label.parent
  183. new_label, is_new = self.project.create_label(
  184. commit=False,
  185. name="New Label",
  186. reference="new_ref",
  187. parent=label,
  188. )
  189. self.assertTrue(is_new)
  190. self.assertEqual(label.id, new_label.parent_id)
  191. self.assertEqual(self.n_roots, self.root_labels.count())
  192. self.assertEqual(self.n_leafs + 1, self.leaf_labels.count())
  193. self.post(self.url(label.id), json=dict(remove=True))
  194. self.assertEqual(label_parent.id, new_label.parent_id)
  195. self.assertEqual(self.n_roots, self.root_labels.count())
  196. self.assertEqual(self.n_leafs, self.leaf_labels.count())
  197. class LabelEditTests(_BaseLabelTests):
  198. def test_edit_name(self):
  199. self.assertEqual(0, self.project.labels.count())
  200. label = self.project.create_label(name="Label", reference="label")
  201. self.assertEqual(1, self.project.labels.count())
  202. label = self.project.labels.first()
  203. self.assertEqual("Label", label.name)
  204. self.post(self.url(label.id), json=dict(), status_code=400)
  205. name = "Another name"
  206. self.post(self.url(label.id), json=dict(name=name))
  207. label = self.project.labels.first()
  208. self.assertEqual(name, label.name)
  209. def test_edit_name(self):
  210. self.assertEqual(0, self.project.labels.count())
  211. label = self.project.create_label(name="Label", reference="label")
  212. self.assertEqual(1, self.project.labels.count())
  213. label = self.project.labels.first()
  214. self.assertEqual("Label", label.name)
  215. url = lambda label_id: url_for("edit_label_name",
  216. project_id=self.project.id, label_id=label_id)
  217. self.post(url(label.id), json=dict(), status_code=400)
  218. name = "Another name"
  219. self.post(url(4242), json=dict(name=name), status_code=404)
  220. self.post(url(label.id), json=dict(name=name))
  221. label = self.project.labels.first()
  222. self.assertEqual(name, label.name)
  223. def test_edit_parent(self):
  224. self.assertEqual(0, self.project.labels.count())
  225. label, _ = self.project.create_label(name="Label", reference="label")
  226. parent, _ = self.project.create_label(name="Label2", reference="label2")
  227. self.assertEqual(2, self.project.labels.count())
  228. label = Label.query.get(label.id)
  229. self.assertIsNone(label.parent_id)
  230. url = lambda label_id: url_for("edit_label_parent",
  231. project_id=self.project.id, label_id=label_id)
  232. self.post(url(label.id), json=dict(), status_code=400)
  233. self.post(url(4242), json=dict(parent=parent.id), status_code=404)
  234. self.post(url(label.id), json=dict(parent=parent.id))
  235. label = Label.query.get(label.id)
  236. self.assertEqual(parent.id, label.parent_id)