crop_tests.py 2.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. import unittest
  2. import numpy as np
  3. from skimage.transform import resize
  4. from cvdatasets.dataset.part.surrogate import SurrogateType
  5. from cvdatasets.dataset.part.base import BasePart
  6. class PartCropTest(unittest.TestCase):
  7. def setUp(self):
  8. self.im = np.random.randn(300, 300, 3).astype(np.uint8)
  9. def _check_crop(self, cropped_im, _should):
  10. self.assertIsNotNone(cropped_im,
  11. "method crop should return something!")
  12. self.assertIsInstance(cropped_im, type(self.im),
  13. "result should have the same type as the input image")
  14. crop_h, crop_w, _ = cropped_im.shape
  15. h, w, _ = _should.shape
  16. self.assertEqual(crop_h, h, "incorrect crop height")
  17. self.assertEqual(crop_w, w, "incorrect crop width")
  18. self.assertTrue((cropped_im == _should).all(),
  19. "crop was incorret")
  20. def test_bbox_part_crop(self):
  21. _id, x, y, w, h = annotation = (0, 20, 20, 100, 100)
  22. part = BasePart.new(self.im, annotation)
  23. cropped_im = part.crop(self.im)
  24. _should = self.im[y:y+h, x:x+w]
  25. self._check_crop(cropped_im, _should)
  26. def test_location_part_crop(self):
  27. _id, center_x, center_y, _vis = annotation = (0, 50, 50, 1)
  28. part = BasePart.new(self.im, annotation)
  29. h, w, c = self.im.shape
  30. for ratio in np.linspace(0.1, 0.3, num=9):
  31. _h, _w = int(h * ratio), int(w * ratio)
  32. cropped_im = part.crop(self.im, ratio=ratio)
  33. x, y = center_x - _h // 2, center_y - _w // 2
  34. _should = self.im[y : y + _h, x : x + _w]
  35. self._check_crop(cropped_im, _should)
  36. def test_non_visible_location_crop(self):
  37. _id, center_x, center_y, _vis = annotation = (0, 50, 50, 0)
  38. def _blank(im, w, h):
  39. return np.zeros((h, w, 3), dtype=im.dtype)
  40. def _middle(im, w, h):
  41. im_h, im_w, c = im.shape
  42. middle_x, middle_y = im_w // 2, im_h // 2
  43. x0 = middle_x - w // 2
  44. y0 = middle_y - h // 2
  45. return im[y0: y0+h, x0: x0+w]
  46. def _image(im, w, h):
  47. return resize(im, (h, w),
  48. mode="constant",
  49. anti_aliasing=True,
  50. preserve_range=True).astype(np.uint8)
  51. shoulds = [
  52. (SurrogateType.BLANK, _blank),
  53. (SurrogateType.MIDDLE, _middle),
  54. (SurrogateType.IMAGE, _image),
  55. ]
  56. for surr_type, should in shoulds:
  57. bbox = BasePart.new(self.im, annotation, surrogate_type=surr_type)
  58. h, w, c = self.im.shape
  59. for ratio in np.linspace(0.1, 0.3, num=9):
  60. _h, _w = int(h * ratio), int(w * ratio)
  61. cropped_im = bbox.crop(self.im, ratio=ratio)
  62. _should = should(self.im, _w, _h)
  63. self._check_crop(cropped_im, _should)