file_tests.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464
  1. import cv2
  2. import io
  3. import numpy as np
  4. import os
  5. import uuid
  6. from PIL import Image
  7. from flask import url_for
  8. from pathlib import Path
  9. from pycs.database.File import File
  10. from pycs.util.FileOperations import BoundingBox
  11. from tests.base import pаtch_tpool_execute
  12. from tests.client.label_tests import _BaseLabelTests
  13. class _BaseFileTests(_BaseLabelTests):
  14. def setupModels(self):
  15. super().setupModels()
  16. root = Path(self.project.root_folder)
  17. data_root = Path(self.project.data_folder)
  18. for folder in [data_root, root / "temp"]:
  19. folder.mkdir(exist_ok=True, parents=True)
  20. def _get_dummy_image_bytes(self, size=(4000, 6000, 3)):
  21. byteImgIO = io.BytesIO()
  22. byteImg = Image.fromarray(np.zeros(size).astype(np.uint8))
  23. byteImg.save(byteImgIO, "JPEG")
  24. byteImgIO.seek(0)
  25. file_content = byteImgIO.read()
  26. return file_content
  27. def _create_dummy_image(self, file_name, size=(4000, 6000, 3)):
  28. absolute_path = os.path.join(self.project.data_folder, file_name)
  29. file_content = self._get_dummy_image_bytes(size=size)
  30. with open(absolute_path, "wb") as f:
  31. f.write(file_content)
  32. return absolute_path, file_content
  33. class FileCreationTests(_BaseFileTests):
  34. @pаtch_tpool_execute
  35. def test_file_upload_project_with_external_data(self, mocked_execute=None):
  36. file_content = b"some content+1"
  37. url = url_for("upload_file", project_id=self.project.id)
  38. self.assertEqual(0, File.query.count())
  39. self.project.external_data = True
  40. self.project.commit()
  41. self.post(url,
  42. data=dict(file=(io.BytesIO(file_content), "image.jpg")),
  43. content_type="multipart/form-data",
  44. status_code=400,
  45. )
  46. self.assertEqual(0, File.query.count())
  47. @pаtch_tpool_execute
  48. def test_file_upload(self, mocked_execute=None):
  49. url = url_for("upload_file", project_id=4242)
  50. self.post(url, data=dict(), status_code=404)
  51. # Creating a dummy image with proper dummy content.
  52. file_content = self._get_dummy_image_bytes()
  53. url = url_for("upload_file", project_id=self.project.id)
  54. self.assertEqual(0, File.query.count())
  55. self.post(url, data=dict(),
  56. status_code=400)
  57. self.assertEqual(0, File.query.count())
  58. self.post(url,
  59. data=dict(file=(io.BytesIO(file_content), "image.jpg")),
  60. content_type="multipart/form-data",
  61. )
  62. self.assertEqual(1, File.query.count())
  63. # this does not work, if we do not set the CONTENT_LENGTH by ourself
  64. # file = File.query.first()
  65. # self.assertEqual(len(file_content), file.size)
  66. class FileDeletionTests(_BaseFileTests):
  67. def test_file_removal(self):
  68. self._create_dummy_image("image.jpg")
  69. file_uuid = str(uuid.uuid1())
  70. file, is_new = self.project.add_file(
  71. uuid=file_uuid,
  72. file_type="image",
  73. name=f"name",
  74. filename=f"image",
  75. extension=".jpg",
  76. size=32*1024,
  77. )
  78. self.assertTrue(is_new)
  79. self.assertEqual(1, self.project.files.count())
  80. self.assertTrue(os.path.exists(file.absolute_path))
  81. url = url_for("remove_file", file_id=file.id)
  82. self.post(url, json=dict(), status_code=400)
  83. self.post(url, json=dict(remove=False), status_code=400)
  84. self.post(url, json=dict(remove=True))
  85. self.assertEqual(0, self.project.files.count())
  86. self.assertFalse(os.path.exists(file.absolute_path))
  87. url = url_for("remove_file", file_id=4242)
  88. self.post(url, json=dict(remove=True), status_code=404)
  89. def test_file_removal_from_project_with_external_data(self):
  90. self._create_dummy_image("image.jpg")
  91. file_uuid = str(uuid.uuid1())
  92. file, is_new = self.project.add_file(
  93. uuid=file_uuid,
  94. file_type="image",
  95. name=f"name",
  96. filename=f"image",
  97. extension=".jpg",
  98. size=32*1024,
  99. )
  100. self.assertTrue(is_new)
  101. self.project.external_data = True
  102. self.assertTrue(os.path.exists(file.absolute_path))
  103. url = url_for("remove_file", file_id=file.id)
  104. self.assertEqual(1, self.project.files.count())
  105. self.post(url, json=dict(remove=True), status_code=400)
  106. self.assertEqual(1, self.project.files.count())
  107. class FileGettingTests(_BaseFileTests):
  108. def test_get_file_getting(self):
  109. file_uuid = str(uuid.uuid1())
  110. file, is_new = self.project.add_file(
  111. uuid=file_uuid,
  112. file_type="image",
  113. name=f"name",
  114. filename=f"image",
  115. extension=".jpg",
  116. size=32*1024,
  117. )
  118. self.assertTrue(is_new)
  119. self.assertEqual(1, self.project.files.count())
  120. url = url_for("get_file", file_id=file.id)
  121. # without an actual file, this GET request returns 404
  122. self.get(url, status_code=404)
  123. _, content = self._create_dummy_image("image.jpg")
  124. response = self.get(url)
  125. self.assertFalse(response.is_json)
  126. self.assertEqual(content, response.data)
  127. response.close()
  128. def test_get_prev_next_file(self):
  129. for i in range(1, 6):
  130. file_uuid = str(uuid.uuid1())
  131. file, is_new = self.project.add_file(
  132. uuid=file_uuid,
  133. file_type="image",
  134. name=f"name_{i}",
  135. filename=f"image_{i}",
  136. extension=".jpg",
  137. size=32*1024,
  138. )
  139. self.assertTrue(is_new)
  140. with open(file.absolute_path, "wb") as f:
  141. f.write(b"some content")
  142. self.assertEqual(5, self.project.files.count())
  143. files = self.project.files.all()
  144. url = url_for("get_previous_and_next_file", file_id=4542)
  145. self.get(url, status_code=404)
  146. for i, file in enumerate(files):
  147. p_file, n_file = None, None
  148. if i != 0:
  149. p_file = files[i-1].serialize()
  150. if i < len(files)-1:
  151. n_file = files[i+1].serialize()
  152. url = url_for("get_previous_and_next_file", file_id=file.id)
  153. response = self.get(url)
  154. self.assertTrue(response.is_json)
  155. content_should = dict(
  156. next=n_file,
  157. nextInCollection=n_file,
  158. previous=p_file,
  159. previousInCollection=p_file,
  160. )
  161. self.assertDictEqual(content_should, response.json)
  162. files[1].delete()
  163. file = files[2]
  164. p_file, n_file = files[0], files[3]
  165. url = url_for("get_previous_and_next_file", file_id=file.id)
  166. response = self.get(url)
  167. self.assertTrue(response.is_json)
  168. content_should = dict(
  169. next=n_file.serialize(),
  170. nextInCollection=n_file.serialize(),
  171. previous=p_file.serialize(),
  172. previousInCollection=p_file.serialize(),
  173. )
  174. self.assertDictEqual(content_should, response.json)
  175. files[3].delete()
  176. file = files[2]
  177. p_file, n_file = files[0], files[4]
  178. url = url_for("get_previous_and_next_file", file_id=file.id)
  179. response = self.get(url)
  180. self.assertTrue(response.is_json)
  181. content_should = dict(
  182. next=n_file.serialize(),
  183. nextInCollection=n_file.serialize(),
  184. previous=p_file.serialize(),
  185. previousInCollection=p_file.serialize(),
  186. )
  187. self.assertDictEqual(content_should, response.json)
  188. class FileResizingTests(_BaseFileTests):
  189. def _add_image(self, shape, file: File):
  190. image = np.random.randint(0, 256, shape).astype(np.uint8)
  191. im = Image.fromarray(image)
  192. im.save(file.absolute_path)
  193. self.assertTrue(os.path.exists(file.absolute_path))
  194. return image
  195. def _compare_images(self, im0, im1, threshold=1e-3):
  196. im0, im1 = im0 / 255, im1 / 255
  197. mse = np.mean((im0 - im1)**2)
  198. self.assertLess(mse, threshold)
  199. @pаtch_tpool_execute
  200. def test_resize_image(self, mocked_execute):
  201. self.get(url_for("get_resized_file", file_id=4242, resolution=300), status_code=404).close()
  202. file_uuid = str(uuid.uuid1())
  203. file, is_new = self.project.add_file(
  204. uuid=file_uuid,
  205. file_type="image",
  206. name=f"name",
  207. filename=f"image",
  208. extension=".png",
  209. size=32*1024,
  210. )
  211. self.assertTrue(is_new)
  212. image = self._add_image((300, 300), file)
  213. for upscale in [300, 1200, 500, 320]:
  214. url = url_for("get_resized_file", file_id=file.id, resolution=upscale)
  215. response = self.get(url)
  216. self.assertFalse(response.is_json)
  217. returned_im = _im_from_bytes(response.data)
  218. response.close()
  219. self.assertEqual(image.shape, returned_im.shape)
  220. self._compare_images(image, returned_im)
  221. # repeat the last scale two times to get the cached resized image
  222. for downscale in [299, 200, 150, 32, 32]:
  223. sm_image = _resize(image, downscale)
  224. url = url_for("get_resized_file", file_id=file.id, resolution=downscale)
  225. response = self.get(url)
  226. self.assertFalse(response.is_json)
  227. returned_im = _im_from_bytes(response.data)
  228. response.close()
  229. self.assertEqual(sm_image.shape, returned_im.shape)
  230. self._compare_images(sm_image, returned_im)
  231. del sm_image
  232. @pаtch_tpool_execute
  233. def test_resize_image_not_found(self, mocked_execute):
  234. file_uuid = str(uuid.uuid1())
  235. file, is_new = self.project.add_file(
  236. uuid=file_uuid,
  237. file_type="image",
  238. name=f"name",
  239. filename=f"image",
  240. extension=".png",
  241. size=32*1024,
  242. )
  243. self.assertTrue(is_new)
  244. image = self._add_image((300, 300), file)
  245. save = file.path
  246. file.path = "/some/nonexisting/path"
  247. file.commit()
  248. url = url_for("get_resized_file", file_id=file.id, resolution=300)
  249. response = self.get(url, status_code=404)
  250. file.path = save
  251. file.commit()
  252. @pаtch_tpool_execute
  253. def test_crop_image_not_found(self, mocked_execute):
  254. file_uuid = str(uuid.uuid1())
  255. file, is_new = self.project.add_file(
  256. uuid=file_uuid,
  257. file_type="image",
  258. name=f"name",
  259. filename=f"image",
  260. extension=".png",
  261. size=32*1024,
  262. )
  263. self.assertTrue(is_new)
  264. image = self._add_image((300, 300), file)
  265. save = file.path
  266. file.path = "/some/nonexisting/path"
  267. file.commit()
  268. url = url_for("get_cropped_file", file_id=file.id,
  269. resolution=300, crop_box="0x0x1x1")
  270. response = self.get(url, status_code=404)
  271. response.close()
  272. file.path = save
  273. file.commit()
  274. @pаtch_tpool_execute
  275. def test_crop_image(self, mocked_execute):
  276. file_uuid = str(uuid.uuid1())
  277. file, is_new = self.project.add_file(
  278. uuid=file_uuid,
  279. file_type="image",
  280. name=f"name",
  281. filename=f"image",
  282. extension=".png",
  283. size=32*1024,
  284. )
  285. self.assertTrue(is_new)
  286. image = self._add_image((300, 300), file)
  287. for box in [(0,0,1,1), (0,0,1/2,1/2), (1/2,1/2, 1, 1), (1/3,1/2,3/4, 1), ]:
  288. url = url_for("get_cropped_file", file_id=file.id,
  289. resolution=300, crop_box="x".join(map(str, box)))
  290. response = self.get(url)
  291. self.assertFalse(response.is_json)
  292. returned_im = _im_from_bytes(response.data)
  293. response.close()
  294. crop = _crop(image, BoundingBox(*box))
  295. self.assertEqual(crop.shape, returned_im.shape)
  296. self._compare_images(crop, returned_im)
  297. def test_automatic_thumbnail_generation(self):
  298. img_size = (4000, 6000, 3)
  299. self._create_dummy_image("image.jpg", size=img_size)
  300. file_uuid = str(uuid.uuid1())
  301. file, is_new = self.project.add_file(
  302. uuid=file_uuid,
  303. file_type="image",
  304. name=f"name",
  305. filename=f"image",
  306. extension=".jpg",
  307. size=32*1024,
  308. )
  309. self.assertTrue(is_new)
  310. self.assertEqual(1, self.project.files.count())
  311. self.assertTrue(os.path.exists(file.absolute_path))
  312. temp_folder = os.path.join(self.project.root_folder, "temp")
  313. for max_width, max_height in [(200, 200), (2000, 800)]:
  314. img_path = os.path.join(temp_folder, file_uuid + "_" + str(max_width) + "_" + str(max_height) + ".jpg")
  315. self.assertTrue(os.path.exists(img_path))
  316. with Image.open(img_path) as img:
  317. width, height = img.size
  318. self.assertTrue(width == max_width or height == max_height)
  319. self.assertLessEqual(width, max_width)
  320. self.assertLessEqual(height, max_height)
  321. self.assertLessEqual(abs(img_size[1] / img_size[0] - width / height), 0.1)
  322. def _im_from_bytes(data: bytes) -> np.ndarray:
  323. return np.asarray(Image.open(io.BytesIO(data)))
  324. def _resize(image: np.ndarray, size: int) -> np.ndarray:
  325. return np.asarray(Image.fromarray(image).resize((size, size)))
  326. def _crop(image: np.ndarray, box: BoundingBox) -> np.ndarray:
  327. h, w, *c = image.shape
  328. x0, y0 = int(w * box.x), int(h * box.y)
  329. crop_w, crop_h = int(w * box.w), int(h * box.h)
  330. x1, y1 = x0 + crop_w, y0 + crop_h
  331. return image[y0:y1, x0:x1]