# file_tests.py (9.1 KB)
  1. import cv2
  2. import io
  3. import numpy as np
  4. import os
  5. import uuid
  6. from PIL import Image
  7. from flask import url_for
  8. from pathlib import Path
  9. from pycs.database.File import File
  10. from tests.base import pаtch_tpool_execute
  11. from tests.client.label_tests import _BaseLabelTests
  12. class _BaseFileTests(_BaseLabelTests):
  13. def setupModels(self):
  14. super().setupModels()
  15. root = Path(self.project.root_folder)
  16. data_root = Path(self.project.data_folder)
  17. for folder in [data_root, root / "temp"]:
  18. folder.mkdir(exist_ok=True, parents=True)
  19. class FileCreationTests(_BaseFileTests):
  20. @pаtch_tpool_execute
  21. def test_file_upload_project_with_external_data(self, mocked_execute=None):
  22. file_content = b"some content+1"
  23. url = url_for("upload_file", project_id=self.project.id)
  24. self.assertEqual(0, File.query.count())
  25. self.project.external_data = True
  26. self.project.commit()
  27. self.post(url,
  28. data=dict(file=(io.BytesIO(file_content), "image.jpg")),
  29. content_type="multipart/form-data",
  30. status_code=400,
  31. )
  32. self.assertEqual(0, File.query.count())
  33. @pаtch_tpool_execute
  34. def test_file_upload(self, mocked_execute=None):
  35. url = url_for("upload_file", project_id=4242)
  36. self.post(url, data=dict(), status_code=404)
  37. file_content = b"some content+1"
  38. url = url_for("upload_file", project_id=self.project.id)
  39. self.assertEqual(0, File.query.count())
  40. self.post(url, data=dict(),
  41. status_code=400)
  42. self.assertEqual(0, File.query.count())
  43. self.post(url,
  44. data=dict(file=(io.BytesIO(file_content), "image.jpg")),
  45. content_type="multipart/form-data",
  46. )
  47. self.assertEqual(1, File.query.count())
  48. # this does not work, if we do not set the CONTENT_LENGTH by ourself
  49. # file = File.query.first()
  50. # self.assertEqual(len(file_content), file.size)
  51. class FileDeletionTests(_BaseFileTests):
  52. def test_file_removal(self):
  53. file_uuid = str(uuid.uuid1())
  54. file, is_new = self.project.add_file(
  55. uuid=file_uuid,
  56. file_type="image",
  57. name=f"name",
  58. filename=f"image",
  59. extension=".jpg",
  60. size=32*1024,
  61. )
  62. self.assertTrue(is_new)
  63. self.assertEqual(1, self.project.files.count())
  64. with open(file.absolute_path, "w"):
  65. pass
  66. self.assertTrue(os.path.exists(file.absolute_path))
  67. url = url_for("remove_file", file_id=file.id)
  68. self.post(url, json=dict(), status_code=400)
  69. self.post(url, json=dict(remove=False), status_code=400)
  70. self.post(url, json=dict(remove=True))
  71. self.assertEqual(0, self.project.files.count())
  72. self.assertFalse(os.path.exists(file.absolute_path))
  73. url = url_for("remove_file", file_id=4242)
  74. self.post(url, json=dict(remove=True), status_code=404)
  75. def test_file_removal_from_project_with_external_data(self):
  76. file_uuid = str(uuid.uuid1())
  77. file, is_new = self.project.add_file(
  78. uuid=file_uuid,
  79. file_type="image",
  80. name=f"name",
  81. filename=f"image",
  82. extension=".jpg",
  83. size=32*1024,
  84. )
  85. self.assertTrue(is_new)
  86. with open(file.absolute_path, "w"):
  87. pass
  88. self.project.external_data = True
  89. self.assertTrue(os.path.exists(file.absolute_path))
  90. url = url_for("remove_file", file_id=file.id)
  91. self.assertEqual(1, self.project.files.count())
  92. self.post(url, json=dict(remove=True), status_code=400)
  93. self.assertEqual(1, self.project.files.count())
  94. class FileGettingTests(_BaseFileTests):
  95. def test_get_file_getting(self):
  96. file_uuid = str(uuid.uuid1())
  97. file, is_new = self.project.add_file(
  98. uuid=file_uuid,
  99. file_type="image",
  100. name=f"name",
  101. filename=f"image",
  102. extension=".jpg",
  103. size=32*1024,
  104. )
  105. self.assertTrue(is_new)
  106. self.assertEqual(1, self.project.files.count())
  107. url = url_for("get_file", file_id=file.id)
  108. # without an actual file, this GET request returns 404
  109. self.get(url, status_code=404)
  110. content = b"some text"
  111. with open(file.absolute_path, "wb") as f:
  112. f.write(content)
  113. response = self.get(url)
  114. self.assertFalse(response.is_json)
  115. self.assertEqual(content, response.data)
  116. def test_get_prev_next_file(self):
  117. for i in range(1, 6):
  118. file_uuid = str(uuid.uuid1())
  119. file, is_new = self.project.add_file(
  120. uuid=file_uuid,
  121. file_type="image",
  122. name=f"name_{i}",
  123. filename=f"image_{i}",
  124. extension=".jpg",
  125. size=32*1024,
  126. )
  127. self.assertTrue(is_new)
  128. with open(file.absolute_path, "wb") as f:
  129. f.write(b"some content")
  130. self.assertEqual(5, self.project.files.count())
  131. files = self.project.files.all()
  132. url = url_for("get_previous_and_next_file", file_id=4542)
  133. self.get(url, status_code=404)
  134. for i, file in enumerate(files):
  135. p_file, n_file = None, None
  136. if i != 0:
  137. p_file = files[i-1].serialize()
  138. if i < len(files)-1:
  139. n_file = files[i+1].serialize()
  140. url = url_for("get_previous_and_next_file", file_id=file.id)
  141. response = self.get(url)
  142. self.assertTrue(response.is_json)
  143. content_should = dict(
  144. next=n_file,
  145. nextInCollection=n_file,
  146. previous=p_file,
  147. previousInCollection=p_file,
  148. )
  149. self.assertDictEqual(content_should, response.json)
  150. files[1].delete()
  151. file = files[2]
  152. p_file, n_file = files[0], files[3]
  153. url = url_for("get_previous_and_next_file", file_id=file.id)
  154. response = self.get(url)
  155. self.assertTrue(response.is_json)
  156. content_should = dict(
  157. next=n_file.serialize(),
  158. nextInCollection=n_file.serialize(),
  159. previous=p_file.serialize(),
  160. previousInCollection=p_file.serialize(),
  161. )
  162. self.assertDictEqual(content_should, response.json)
  163. files[3].delete()
  164. file = files[2]
  165. p_file, n_file = files[0], files[4]
  166. url = url_for("get_previous_and_next_file", file_id=file.id)
  167. response = self.get(url)
  168. self.assertTrue(response.is_json)
  169. content_should = dict(
  170. next=n_file.serialize(),
  171. nextInCollection=n_file.serialize(),
  172. previous=p_file.serialize(),
  173. previousInCollection=p_file.serialize(),
  174. )
  175. self.assertDictEqual(content_should, response.json)
  176. class FileResizingTests(_BaseFileTests):
  177. def _add_image(self, shape, file: File):
  178. image = np.random.randint(0, 256, shape).astype(np.uint8)
  179. im = Image.fromarray(image)
  180. im.save(file.absolute_path)
  181. self.assertTrue(os.path.exists(file.absolute_path))
  182. return image
  183. def _compare_images(self, im0, im1, threshold=1e-3):
  184. im0, im1 = im0 / 255, im1 / 255
  185. mse = np.mean((im0 - im1)**2)
  186. self.assertLess(mse, threshold)
    @pаtch_tpool_execute
    def test_resize_image(self, mocked_execute):
        """Exercise the ``get_resized_file`` endpoint with a 300x300 PNG.

        The test expects: 404 for an unknown file id; for resolutions of at
        least 300 an image with the original shape and (nearly) the original
        pixels; for smaller resolutions an image close to PIL's ``resize`` of
        the original; and 404 once the file's path points to a non-existing
        location.
        """
        self.get(url_for("get_resized_file", file_id=4242, resolution=300), status_code=404)

        file_uuid = str(uuid.uuid1())
        file, is_new = self.project.add_file(
            uuid=file_uuid,
            file_type="image",
            name=f"name",
            filename=f"image",
            extension=".png",
            size=32*1024,
        )
        self.assertTrue(is_new)

        # random single-channel (2-D) 300x300 image written to file.absolute_path
        image = self._add_image((300, 300), file)

        # resolutions >= the image size: same shape and (nearly) same pixels expected
        for upscale in [300, 1200, 500, 320]:
            url = url_for("get_resized_file", file_id=file.id, resolution=upscale)
            response = self.get(url)
            self.assertFalse(response.is_json)
            returned_im = np.asarray(Image.open(io.BytesIO(response.data)))
            self.assertEqual(image.shape, returned_im.shape)
            self._compare_images(image, returned_im)

        # smaller resolutions: compare against a PIL resize of the original image
        # repeat the last scale two times to get the cached resized image
        for downscale in [299, 200, 150, 32, 32]:
            sm_image = np.asarray(Image.fromarray(image).resize((downscale, downscale)))
            url = url_for("get_resized_file", file_id=file.id, resolution=downscale)
            response = self.get(url)
            self.assertFalse(response.is_json)
            returned_im = np.asarray(Image.open(io.BytesIO(response.data)))
            self.assertEqual(sm_image.shape, returned_im.shape)
            self._compare_images(sm_image, returned_im)
            del sm_image

        # temporarily point the file to a non-existing path; the endpoint
        # is then expected to answer with 404
        save = file.path
        file.path = "/some/nonexisting/path"
        file.commit()

        # NOTE(review): `upscale` is the leftover loop variable of the first
        # loop (its last value, 320); an explicit resolution would decouple
        # this request from that loop.
        url = url_for("get_resized_file", file_id=file.id, resolution=upscale)
        response = self.get(url, status_code=404)

        # restore the original path
        file.path = save
        file.commit()