6
0

test_client.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207
  1. import io
  2. import time
  3. import eventlet
  4. from test.base import BaseTestCase
  5. from pycs.database.File import File
  6. from pycs.database.Result import Result
  7. from pycs.database.Label import Label
  8. from pycs.database.Project import Project
  9. class ClientTests(BaseTestCase):
  10. def _do_request(self, request_func, *args, status_code=200, **kwargs):
  11. response = request_func(*args, follow_redirects=True, **kwargs)
  12. self.assertEqual(response.status_code, status_code, response.get_data().decode())
  13. return response
  14. def _post(self, url, status_code=200, content_type=None, json=None, data=None):
  15. return self._do_request(self.client.post,
  16. url,
  17. status_code=status_code,
  18. json=json,
  19. data=data,
  20. content_type=content_type,
  21. )
  22. def _get(self, url, status_code=200, content_type=None, json=None, data=None):
  23. return self._do_request(self.client.get,
  24. url,
  25. status_code=status_code,
  26. json=json,
  27. data=data,
  28. content_type=content_type,
  29. )
  30. def test_project_creation(self):
  31. self.assertEqual(0, Project.query.count())
  32. self.assertEqual(0, Label.query.count())
  33. self._post(
  34. "/projects",
  35. json=dict(
  36. name="some name",
  37. description="some description",
  38. model=1,
  39. label=2,
  40. external=None,
  41. )
  42. )
  43. self.assertEqual(1, Project.query.count())
  44. project = Project.query.first()
  45. self.assertIsNotNone(project)
  46. self.assertIsNotNone(project.model)
  47. self.assertIsNotNone(project.label_provider)
  48. self.wait_for_coroutines()
  49. self.assertNotEqual(0, Label.query.count())
  50. def test_adding_file_with_result(self):
  51. self._post("/projects",
  52. json=dict(
  53. name="some name",
  54. description="some description",
  55. model=1,
  56. label=2,
  57. external=None,
  58. )
  59. )
  60. self.assertEqual(1, Project.query.count())
  61. project = Project.query.first()
  62. self.wait_for_coroutines()
  63. self.assertEqual(0, File.query.count())
  64. self._post(f"/projects/{project.id}/data",
  65. data=dict(file=(io.BytesIO(b"some content"), "image.jpg")),
  66. content_type="multipart/form-data",
  67. )
  68. self.assertEqual(1, File.query.count())
  69. file = File.query.first()
  70. self.assertEqual(0, Result.query.count())
  71. self._post(f"data/{file.id}/results",
  72. json=dict(
  73. type="bounding-box",
  74. data=dict(x0=0, x1=0, y0=0, y1=0),
  75. label=2,
  76. )
  77. )
  78. self.assertEqual(1, Result.query.count())
  79. def test_cascade_after_project_removal(self):
  80. self.assertEqual(0, File.query.count())
  81. self.assertEqual(0, Result.query.count())
  82. self.assertEqual(0, Label.query.count())
  83. self.assertEqual(0, Project.query.count())
  84. self._post("/projects",
  85. json=dict(
  86. name="some name",
  87. description="some description",
  88. model=1,
  89. label=2,
  90. external=None,
  91. )
  92. )
  93. project = Project.query.first()
  94. project_id = project.id
  95. self.wait_for_coroutines()
  96. self._post(f"/projects/{project_id}/data",
  97. data=dict(file=(io.BytesIO(b"some content"), "image.jpg")),
  98. content_type="multipart/form-data",
  99. )
  100. file = File.query.first()
  101. file_id = file.id
  102. self.wait_for_coroutines()
  103. self._post(f"data/{file_id}/results",
  104. json=dict(
  105. type="bounding-box",
  106. data=dict(x=0, y=0, w=0, h=0),
  107. label=2,
  108. )
  109. )
  110. self.assertNotEqual(0, File.query.count())
  111. self.assertNotEqual(0, Result.query.count())
  112. self.assertNotEqual(0, Label.query.count())
  113. self.assertNotEqual(0, Project.query.count())
  114. self.wait_for_coroutines()
  115. eventlet.sleep(3)
  116. self._post(f"/projects/{project_id}/remove",
  117. json=dict(remove=True),
  118. )
  119. self.assertEqual(0, Project.query.count())
  120. self.assertEqual(0, Label.query.count())
  121. self.assertEqual(0, File.query.count())
  122. self.assertEqual(0, Result.query.count())
  123. def test_result_download(self):
  124. self._post("/projects",
  125. json=dict(
  126. name="some name",
  127. description="some description",
  128. model=1,
  129. label=2,
  130. external=None,
  131. )
  132. )
  133. project = Project.query.first()
  134. project_id = project.id
  135. self.wait_for_coroutines()
  136. self._post(f"/projects/{project_id}/data",
  137. data=dict(file=(io.BytesIO(b"some content"), "image.jpg")),
  138. content_type="multipart/form-data",
  139. )
  140. file = File.query.first()
  141. file_id = file.id
  142. self.wait_for_coroutines()
  143. self._post(f"data/{file_id}/results",
  144. json=dict(
  145. type="bounding-box",
  146. data=dict(x=0, y=0, w=0, h=0),
  147. label=2,
  148. )
  149. )
  150. self.wait_for_coroutines()
  151. eventlet.sleep(3)
  152. response = self._get(f"/projects/{project_id}/results")
  153. self.assertTrue(response.is_json)
  154. file = File.query.first()
  155. result = Result.query.first()
  156. self.assertEqual(1, len(response.json))
  157. returned_file = response.json[0]
  158. self.assertEqual(returned_file["filename"], file.filename)
  159. self.assertEqual(1, len(returned_file["results"]))
  160. returned_result = returned_file["results"][0]
  161. xywh_should = [result.data.get(attr) for attr in "xywh"]
  162. xywh_is = [returned_result.get(attr) for attr in "xywh"]
  163. self.assertListEqual(xywh_is, xywh_should)
  164. self.assertEqual(returned_result["label"]["id"], result.label.id)
  165. self.assertEqual(returned_result["label"]["name"], result.label.name)