6
0

file_tests.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419
  1. import cv2
  2. import io
  3. import numpy as np
  4. import os
  5. import uuid
  6. from PIL import Image
  7. from flask import url_for
  8. from pathlib import Path
  9. from pycs.database.File import File
  10. from pycs.util.FileOperations import BoundingBox
  11. from tests.base import pаtch_tpool_execute
  12. from tests.client.label_tests import _BaseLabelTests
  13. class _BaseFileTests(_BaseLabelTests):
  14. def setupModels(self):
  15. super().setupModels()
  16. root = Path(self.project.root_folder)
  17. data_root = Path(self.project.data_folder)
  18. for folder in [data_root, root / "temp"]:
  19. folder.mkdir(exist_ok=True, parents=True)
  20. class FileCreationTests(_BaseFileTests):
  21. @pаtch_tpool_execute
  22. def test_file_upload_project_with_external_data(self, mocked_execute=None):
  23. file_content = b"some content+1"
  24. url = url_for("upload_file", project_id=self.project.id)
  25. self.assertEqual(0, File.query.count())
  26. self.project.external_data = True
  27. self.project.commit()
  28. self.post(url,
  29. data=dict(file=(io.BytesIO(file_content), "image.jpg")),
  30. content_type="multipart/form-data",
  31. status_code=400,
  32. )
  33. self.assertEqual(0, File.query.count())
  34. @pаtch_tpool_execute
  35. def test_file_upload(self, mocked_execute=None):
  36. url = url_for("upload_file", project_id=4242)
  37. self.post(url, data=dict(), status_code=404)
  38. file_content = b"some content+1"
  39. url = url_for("upload_file", project_id=self.project.id)
  40. self.assertEqual(0, File.query.count())
  41. self.post(url, data=dict(),
  42. status_code=400)
  43. self.assertEqual(0, File.query.count())
  44. self.post(url,
  45. data=dict(file=(io.BytesIO(file_content), "image.jpg")),
  46. content_type="multipart/form-data",
  47. )
  48. self.assertEqual(1, File.query.count())
  49. # this does not work, if we do not set the CONTENT_LENGTH by ourself
  50. # file = File.query.first()
  51. # self.assertEqual(len(file_content), file.size)
  52. class FileDeletionTests(_BaseFileTests):
  53. def test_file_removal(self):
  54. file_uuid = str(uuid.uuid1())
  55. file, is_new = self.project.add_file(
  56. uuid=file_uuid,
  57. file_type="image",
  58. name=f"name",
  59. filename=f"image",
  60. extension=".jpg",
  61. size=32*1024,
  62. )
  63. self.assertTrue(is_new)
  64. self.assertEqual(1, self.project.files.count())
  65. with open(file.absolute_path, "w"):
  66. pass
  67. self.assertTrue(os.path.exists(file.absolute_path))
  68. url = url_for("remove_file", file_id=file.id)
  69. self.post(url, json=dict(), status_code=400)
  70. self.post(url, json=dict(remove=False), status_code=400)
  71. self.post(url, json=dict(remove=True))
  72. self.assertEqual(0, self.project.files.count())
  73. self.assertFalse(os.path.exists(file.absolute_path))
  74. url = url_for("remove_file", file_id=4242)
  75. self.post(url, json=dict(remove=True), status_code=404)
  76. def test_file_removal_from_project_with_external_data(self):
  77. file_uuid = str(uuid.uuid1())
  78. file, is_new = self.project.add_file(
  79. uuid=file_uuid,
  80. file_type="image",
  81. name=f"name",
  82. filename=f"image",
  83. extension=".jpg",
  84. size=32*1024,
  85. )
  86. self.assertTrue(is_new)
  87. with open(file.absolute_path, "w"):
  88. pass
  89. self.project.external_data = True
  90. self.assertTrue(os.path.exists(file.absolute_path))
  91. url = url_for("remove_file", file_id=file.id)
  92. self.assertEqual(1, self.project.files.count())
  93. self.post(url, json=dict(remove=True), status_code=400)
  94. self.assertEqual(1, self.project.files.count())
  95. class FileGettingTests(_BaseFileTests):
  96. def test_get_file_getting(self):
  97. file_uuid = str(uuid.uuid1())
  98. file, is_new = self.project.add_file(
  99. uuid=file_uuid,
  100. file_type="image",
  101. name=f"name",
  102. filename=f"image",
  103. extension=".jpg",
  104. size=32*1024,
  105. )
  106. self.assertTrue(is_new)
  107. self.assertEqual(1, self.project.files.count())
  108. url = url_for("get_file", file_id=file.id)
  109. # without an actual file, this GET request returns 404
  110. self.get(url, status_code=404)
  111. content = b"some text"
  112. with open(file.absolute_path, "wb") as f:
  113. f.write(content)
  114. response = self.get(url)
  115. self.assertFalse(response.is_json)
  116. self.assertEqual(content, response.data)
  117. response.close()
  118. def test_get_prev_next_file(self):
  119. for i in range(1, 6):
  120. file_uuid = str(uuid.uuid1())
  121. file, is_new = self.project.add_file(
  122. uuid=file_uuid,
  123. file_type="image",
  124. name=f"name_{i}",
  125. filename=f"image_{i}",
  126. extension=".jpg",
  127. size=32*1024,
  128. )
  129. self.assertTrue(is_new)
  130. with open(file.absolute_path, "wb") as f:
  131. f.write(b"some content")
  132. self.assertEqual(5, self.project.files.count())
  133. files = self.project.files.all()
  134. url = url_for("get_previous_and_next_file", file_id=4542)
  135. self.get(url, status_code=404)
  136. for i, file in enumerate(files):
  137. p_file, n_file = None, None
  138. if i != 0:
  139. p_file = files[i-1].serialize()
  140. if i < len(files)-1:
  141. n_file = files[i+1].serialize()
  142. url = url_for("get_previous_and_next_file", file_id=file.id)
  143. response = self.get(url)
  144. self.assertTrue(response.is_json)
  145. content_should = dict(
  146. next=n_file,
  147. nextInCollection=n_file,
  148. previous=p_file,
  149. previousInCollection=p_file,
  150. )
  151. self.assertDictEqual(content_should, response.json)
  152. files[1].delete()
  153. file = files[2]
  154. p_file, n_file = files[0], files[3]
  155. url = url_for("get_previous_and_next_file", file_id=file.id)
  156. response = self.get(url)
  157. self.assertTrue(response.is_json)
  158. content_should = dict(
  159. next=n_file.serialize(),
  160. nextInCollection=n_file.serialize(),
  161. previous=p_file.serialize(),
  162. previousInCollection=p_file.serialize(),
  163. )
  164. self.assertDictEqual(content_should, response.json)
  165. files[3].delete()
  166. file = files[2]
  167. p_file, n_file = files[0], files[4]
  168. url = url_for("get_previous_and_next_file", file_id=file.id)
  169. response = self.get(url)
  170. self.assertTrue(response.is_json)
  171. content_should = dict(
  172. next=n_file.serialize(),
  173. nextInCollection=n_file.serialize(),
  174. previous=p_file.serialize(),
  175. previousInCollection=p_file.serialize(),
  176. )
  177. self.assertDictEqual(content_should, response.json)
  178. class FileResizingTests(_BaseFileTests):
  179. def _add_image(self, shape, file: File):
  180. image = np.random.randint(0, 256, shape).astype(np.uint8)
  181. im = Image.fromarray(image)
  182. im.save(file.absolute_path)
  183. self.assertTrue(os.path.exists(file.absolute_path))
  184. return image
  185. def _compare_images(self, im0, im1, threshold=1e-3):
  186. im0, im1 = im0 / 255, im1 / 255
  187. mse = np.mean((im0 - im1)**2)
  188. self.assertLess(mse, threshold)
  189. @pаtch_tpool_execute
  190. def test_resize_image(self, mocked_execute):
  191. self.get(url_for("get_resized_file", file_id=4242, resolution=300), status_code=404).close()
  192. file_uuid = str(uuid.uuid1())
  193. file, is_new = self.project.add_file(
  194. uuid=file_uuid,
  195. file_type="image",
  196. name=f"name",
  197. filename=f"image",
  198. extension=".png",
  199. size=32*1024,
  200. )
  201. self.assertTrue(is_new)
  202. image = self._add_image((300, 300), file)
  203. for upscale in [300, 1200, 500, 320]:
  204. url = url_for("get_resized_file", file_id=file.id, resolution=upscale)
  205. response = self.get(url)
  206. self.assertFalse(response.is_json)
  207. returned_im = _im_from_bytes(response.data)
  208. response.close()
  209. self.assertEqual(image.shape, returned_im.shape)
  210. self._compare_images(image, returned_im)
  211. # repeat the last scale two times to get the cached resized image
  212. for downscale in [299, 200, 150, 32, 32]:
  213. sm_image = _resize(image, downscale)
  214. url = url_for("get_resized_file", file_id=file.id, resolution=downscale)
  215. response = self.get(url)
  216. self.assertFalse(response.is_json)
  217. returned_im = _im_from_bytes(response.data)
  218. response.close()
  219. self.assertEqual(sm_image.shape, returned_im.shape)
  220. self._compare_images(sm_image, returned_im)
  221. del sm_image
  222. @pаtch_tpool_execute
  223. def test_resize_image_not_found(self, mocked_execute):
  224. file_uuid = str(uuid.uuid1())
  225. file, is_new = self.project.add_file(
  226. uuid=file_uuid,
  227. file_type="image",
  228. name=f"name",
  229. filename=f"image",
  230. extension=".png",
  231. size=32*1024,
  232. )
  233. self.assertTrue(is_new)
  234. image = self._add_image((300, 300), file)
  235. save = file.path
  236. file.path = "/some/nonexisting/path"
  237. file.commit()
  238. url = url_for("get_resized_file", file_id=file.id, resolution=300)
  239. response = self.get(url, status_code=404)
  240. file.path = save
  241. file.commit()
  242. @pаtch_tpool_execute
  243. def test_crop_image_not_found(self, mocked_execute):
  244. file_uuid = str(uuid.uuid1())
  245. file, is_new = self.project.add_file(
  246. uuid=file_uuid,
  247. file_type="image",
  248. name=f"name",
  249. filename=f"image",
  250. extension=".png",
  251. size=32*1024,
  252. )
  253. self.assertTrue(is_new)
  254. image = self._add_image((300, 300), file)
  255. save = file.path
  256. file.path = "/some/nonexisting/path"
  257. file.commit()
  258. url = url_for("get_cropped_file", file_id=file.id,
  259. resolution=300, crop_box="0x0x1x1")
  260. response = self.get(url, status_code=404)
  261. response.close()
  262. file.path = save
  263. file.commit()
  264. @pаtch_tpool_execute
  265. def test_crop_image(self, mocked_execute):
  266. file_uuid = str(uuid.uuid1())
  267. file, is_new = self.project.add_file(
  268. uuid=file_uuid,
  269. file_type="image",
  270. name=f"name",
  271. filename=f"image",
  272. extension=".png",
  273. size=32*1024,
  274. )
  275. self.assertTrue(is_new)
  276. image = self._add_image((300, 300), file)
  277. for box in [(0,0,1,1), (0,0,1/2,1/2), (1/2,1/2, 1, 1), (1/3,1/2,3/4, 1), ]:
  278. url = url_for("get_cropped_file", file_id=file.id,
  279. resolution=300, crop_box="x".join(map(str, box)))
  280. response = self.get(url)
  281. self.assertFalse(response.is_json)
  282. returned_im = _im_from_bytes(response.data)
  283. response.close()
  284. crop = _crop(image, BoundingBox(*box))
  285. self.assertEqual(crop.shape, returned_im.shape)
  286. self._compare_images(crop, returned_im)
  287. def _im_from_bytes(data: bytes) -> np.ndarray:
  288. return np.asarray(Image.open(io.BytesIO(data)))
  289. def _resize(image: np.ndarray, size: int) -> np.ndarray:
  290. return np.asarray(Image.fromarray(image).resize((size, size)))
  291. def _crop(image: np.ndarray, box: BoundingBox) -> np.ndarray:
  292. h, w, *c = image.shape
  293. x0, y0 = int(w * box.x), int(h * box.y)
  294. crop_w, crop_h = int(w * box.w), int(h * box.h)
  295. x1, y1 = x0 + crop_w, y0 + crop_h
  296. return image[y0:y1, x0:x1]