test_client.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143
  1. import io
  2. import time
  3. import eventlet
  4. from test.base import BaseTestCase
  5. from pycs.database.File import File
  6. from pycs.database.Result import Result
  7. from pycs.database.Label import Label
  8. from pycs.database.Project import Project
  9. class ClientTests(BaseTestCase):
  10. def _post(self, url, status_code=200, content_type=None, json=None, data=None):
  11. response = self.client.post(url,
  12. json=json,
  13. data=data,
  14. follow_redirects=True,
  15. content_type=content_type,
  16. )
  17. self.assertEqual(response.status_code, 200, response.get_data().decode())
  18. return response
  19. def test_project_creation(self):
  20. self.assertEqual(0, Project.query.count())
  21. self.assertEqual(0, Label.query.count())
  22. self._post(
  23. "/projects",
  24. json=dict(
  25. name="some name",
  26. description="some description",
  27. model=1,
  28. label=2,
  29. external=None,
  30. )
  31. )
  32. self.assertEqual(1, Project.query.count())
  33. project = Project.query.first()
  34. self.assertIsNotNone(project)
  35. self.assertIsNotNone(project.model)
  36. self.assertIsNotNone(project.label_provider)
  37. self.wait_for_coroutines()
  38. self.assertNotEqual(0, Label.query.count())
  39. def test_adding_file_with_result(self):
  40. self._post("/projects",
  41. json=dict(
  42. name="some name",
  43. description="some description",
  44. model=1,
  45. label=2,
  46. external=None,
  47. )
  48. )
  49. self.assertEqual(1, Project.query.count())
  50. project = Project.query.first()
  51. self.wait_for_coroutines()
  52. self.assertEqual(0, File.query.count())
  53. self._post(f"/projects/{project.id}/data",
  54. data=dict(file=(io.BytesIO(b"some content"), "image.jpg")),
  55. content_type="multipart/form-data",
  56. )
  57. self.assertEqual(1, File.query.count())
  58. file = File.query.first()
  59. self.assertEqual(0, Result.query.count())
  60. self._post(f"data/{file.id}/results",
  61. json=dict(
  62. type="bounding-box",
  63. data=dict(x0=0, x1=0, y0=0, y1=0),
  64. label=2,
  65. )
  66. )
  67. self.assertEqual(1, Result.query.count())
  68. def test_cascade_after_project_removal(self):
  69. self.assertEqual(0, File.query.count())
  70. self.assertEqual(0, Result.query.count())
  71. self.assertEqual(0, Label.query.count())
  72. self.assertEqual(0, Project.query.count())
  73. self._post("/projects",
  74. json=dict(
  75. name="some name",
  76. description="some description",
  77. model=1,
  78. label=2,
  79. external=None,
  80. )
  81. )
  82. project = Project.query.first()
  83. project_id = project.id
  84. self.wait_for_coroutines()
  85. self._post(f"/projects/{project_id}/data",
  86. data=dict(file=(io.BytesIO(b"some content"), "image.jpg")),
  87. content_type="multipart/form-data",
  88. )
  89. file = File.query.first()
  90. file_id = file.id
  91. self.wait_for_coroutines()
  92. self._post(f"data/{file_id}/results",
  93. json=dict(
  94. type="bounding-box",
  95. data=dict(x0=0, x1=0, y0=0, y1=0),
  96. label=2,
  97. )
  98. )
  99. self.assertNotEqual(0, File.query.count())
  100. self.assertNotEqual(0, Result.query.count())
  101. self.assertNotEqual(0, Label.query.count())
  102. self.assertNotEqual(0, Project.query.count())
  103. self.wait_for_coroutines()
  104. eventlet.sleep(3)
  105. self._post(f"/projects/{project_id}/remove",
  106. json=dict(remove=True),
  107. )
  108. self.assertEqual(0, Project.query.count())
  109. self.assertEqual(0, Label.query.count())
  110. self.assertEqual(0, File.query.count())
  111. self.assertEqual(0, Result.query.count())