result_tests.py 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284
  1. import uuid
  2. from flask import url_for
  3. from pycs.database.Label import Label
  4. from pycs.database.Result import Result
  5. from tests.client.file_tests import _BaseFileTests
  6. class _BaseResultTests(_BaseFileTests):
  7. def setupModels(self):
  8. super().setupModels()
  9. file_uuid = str(uuid.uuid1())
  10. self.file, is_new = self.project.add_file(
  11. uuid=file_uuid,
  12. file_type="image",
  13. name=f"name",
  14. filename=f"image",
  15. extension=".jpg",
  16. size=32*1024,
  17. )
  18. assert is_new, "the created file should be new!"
  19. class ResultCreationTests(_BaseResultTests):
  20. def test_missing_file(self):
  21. url = url_for("create_result", file_id=4242)
  22. self.post(url, status_code=404)
  23. def test_missing_flags(self):
  24. url = url_for("create_result", file_id=self.file.id)
  25. bad_data = [
  26. None, # no request data at all
  27. dict(), # type missing
  28. dict(type="something"), # should be "labeled-image" or "bounding-box"
  29. dict(type="labeled_image"), # should be with "-"
  30. dict(type="bounding_box"), # should be with "-"
  31. dict(type="labeled-image"), # label is missing
  32. dict(type="bounding-box"), # data is missing
  33. ]
  34. for data in bad_data:
  35. self.assertEqual(0, Result.query.count())
  36. self.post(url, status_code=400, json=data)
  37. self.assertEqual(0, Result.query.count())
  38. def test_file_label(self):
  39. url = url_for("create_result", file_id=self.file.id)
  40. label, is_new = self.project.create_label(name="label", reference="some_label")
  41. self.assertTrue(is_new)
  42. self.assertEqual(0, Result.query.count())
  43. self.post(url, json=dict(type="labeled-image", label=label.id))
  44. self.assertEqual(1, Result.query.count())
  45. result = Result.query.one_or_none()
  46. self.assertIsNotNone(result)
  47. self.assertEqual("user", result.origin)
  48. self.assertEqual(self.file.id, result.file_id)
  49. self.assertEqual(label.id, result.label_id)
  50. self.assertEqual(label.name, result.label.name)
  51. self.assertIsNone(result.data_encoded)
  52. self.assertIsNone(result.data)
  53. def test_bounding_box(self):
  54. url = url_for("create_result", file_id=self.file.id)
  55. self.assertEqual(0, Result.query.count())
  56. box = dict(x=0, y=0.5, w=1/3, h=1/4)
  57. self.post(url, json=dict(type="bounding-box", data=box))
  58. self.assertEqual(1, Result.query.count())
  59. result = Result.query.one_or_none()
  60. self.assertIsNotNone(result)
  61. self.assertEqual("user", result.origin)
  62. self.assertEqual(self.file.id, result.file_id)
  63. self.assertIsNotNone(result.data_encoded)
  64. self.assertDictEqual(box, result.data)
  65. self.assertIsNone(result.label_id)
  66. class ResultGettingTests(_BaseResultTests):
  67. def test_missing_file(self):
  68. url = url_for("get_results", file_id=4242)
  69. self.get(url, status_code=404)
  70. def test_getting_of_results(self):
  71. n = 5
  72. self.assertEqual(0, Result.query.count())
  73. results = {}
  74. for i in range(n):
  75. box = dict(x=0, y=0, w=0.9, h=1.0)
  76. res = self.file.create_result("user", "bounding-box", data=box)
  77. results[res.id] = res
  78. self.assertEqual(5, Result.query.count())
  79. file_uuid = str(uuid.uuid1())
  80. another_file, is_new = self.project.add_file(
  81. uuid=file_uuid,
  82. file_type="image",
  83. name=f"name2",
  84. filename=f"image2",
  85. extension=".jpg",
  86. size=32*1024,
  87. )
  88. self.assertTrue(is_new)
  89. for i in range(n):
  90. box = dict(x=0, y=0, w=0.9, h=1.0)
  91. another_file.create_result("user", "bounding-box", data=box)
  92. self.assertEqual(10, Result.query.count())
  93. url = url_for("get_results", file_id=self.file.id)
  94. response = self.get(url)
  95. self.assertTrue(response.is_json)
  96. content = response.json
  97. self.assertEqual(5, len(content))
  98. for entry in content:
  99. res = results[entry["id"]]
  100. self.assertDictEqual(res.serialize(), entry)
  101. class ResultEditingTests(_BaseResultTests):
  102. def test_edit_data(self):
  103. self.post(url_for("edit_result_data", result_id=4242), status_code=404)
  104. box0 = dict(x=0, y=0, w=0.9, h=1.0)
  105. result = self.file.create_result("pipeline", "bounding-box", data=box0)
  106. self.assertDictEqual(box0, result.data)
  107. self.assertEqual("pipeline", result.origin)
  108. url = url_for("edit_result_data", result_id=result.id)
  109. for data in [None, dict()]:
  110. self.post(url, status_code=400, json=data)
  111. box1 = dict(x=0, y=0, w=0.9, h=1.0)
  112. self.post(url, json=dict(data=box1))
  113. self.assertDictEqual(box1, result.data)
  114. self.assertEqual("user", result.origin)
  115. def test_edit_label(self):
  116. self.post(url_for("edit_result_label", result_id=4242), status_code=404)
  117. label1, is_new = self.project.create_label(name="label1", reference="some_label1")
  118. self.assertTrue(is_new)
  119. label2, is_new = self.project.create_label(name="label2", reference="some_label2")
  120. self.assertTrue(is_new)
  121. for result_type in ["labeled-image", "bounding-box"]:
  122. result = self.file.create_result("pipeline", result_type, label=label1)
  123. self.assertEqual(label1.id, result.label_id)
  124. self.assertEqual(label1.name, result.label.name)
  125. self.assertEqual("pipeline", result.origin)
  126. url = url_for("edit_result_label", result_id=result.id)
  127. for data in [None, dict()]:
  128. self.post(url, status_code=400, json=data)
  129. if result_type == "labeled-image":
  130. self.post(url, status_code=400, json=dict(label=None))
  131. self.post(url, json=dict(label=label2.id))
  132. self.assertEqual(label2.id, result.label_id)
  133. self.assertEqual(label2.name, result.label.name)
  134. self.assertEqual("user", result.origin)
  135. def test_unset_label_of_bounding_box(self):
  136. label, is_new = self.project.create_label(name="label", reference="some_label")
  137. self.assertTrue(is_new)
  138. result = self.file.create_result("pipeline", "bounding-box", label=label)
  139. self.assertEqual(label.id, result.label_id)
  140. self.assertEqual(label.name, result.label.name)
  141. self.assertEqual("pipeline", result.origin)
  142. url = url_for("edit_result_label", result_id=result.id)
  143. self.post(url, json=dict(label=None))
  144. self.assertEqual(None, result.label_id)
  145. self.assertEqual("user", result.origin)
  146. def test_confirm_result(self):
  147. self.post(url_for("confirm_result", result_id=4242), status_code=404)
  148. label, is_new = self.project.create_label(name="label", reference="some_label1")
  149. for result_type in ["labeled-image", "bounding-box"]:
  150. result = self.file.create_result("pipeline", result_type, label=label)
  151. url = url_for("confirm_result", result_id=result.id)
  152. for data in [None, dict(), dict(confirm=False)]:
  153. self.post(url, status_code=400, json=data)
  154. self.assertEqual("pipeline", result.origin)
  155. self.post(url, json=dict(confirm=True))
  156. self.assertEqual("user", result.origin)
  157. class ResultRemovalTests(_BaseResultTests):
  158. def test_remove_result(self):
  159. self.post(url_for("remove_result", result_id=4242), status_code=404)
  160. label, is_new = self.project.create_label(name="label", reference="some_label1")
  161. self.assertEqual(0, Result.query.count())
  162. for result_type in ["labeled-image", "bounding-box"]:
  163. result = self.file.create_result("pipeline", result_type, label=label)
  164. self.assertEqual(1, Result.query.count())
  165. self.assertEqual(1, self.file.results.count())
  166. url = url_for("remove_result", result_id=result.id)
  167. for data in [None, dict(), dict(remove=False)]:
  168. self.post(url, status_code=400, json=data)
  169. self.assertEqual(1, Result.query.count())
  170. self.assertEqual(1, self.file.results.count())
  171. self.post(url, json=dict(remove=True))
  172. self.assertEqual(0, Result.query.count())
  173. self.assertEqual(0, self.file.results.count())
  174. self.assertEqual(0, Result.query.count())
  175. def test_reset_file_results(self):
  176. self.post(url_for("reset_results", file_id=4242), status_code=404)
  177. label, is_new = self.project.create_label(name="label", reference="some_label1")
  178. self.assertEqual(0, Result.query.count())
  179. n = 5
  180. for result_type in ["labeled-image", "bounding-box"]:
  181. for i in range(n):
  182. self.file.create_result("pipeline", result_type, label=label)
  183. url = url_for("reset_results", file_id=self.file.id)
  184. for data in [None, dict(), dict(reset=False)]:
  185. self.post(url, status_code=400, json=data)
  186. self.assertEqual(n, Result.query.count())
  187. self.assertEqual(n, self.file.results.count())
  188. self.post(url, json=dict(reset=True))
  189. self.assertEqual(0, Result.query.count())
  190. self.assertEqual(0, self.file.results.count())