image.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158
  1. from imageio import imread
  2. from PIL import Image
  3. from os.path import isfile
  4. import copy
  5. import numpy as np
  6. from . import utils
  7. from .part import Parts
  def should_have_parts(func):
      """Decorate a method so it fails fast when the instance has no parts.

      The returned wrapper raises ``AssertionError("parts are not present!")``
      when ``self.has_parts`` is falsy and otherwise forwards all positional
      and keyword arguments to *func*. ``functools.wraps`` preserves the
      wrapped method's name and docstring.
      """
      from functools import wraps

      @wraps(func)
      def inner(self, *args, **kwargs):
          # explicit raise instead of ``assert`` so the check survives
          # ``python -O``; AssertionError is kept for callers catching it
          if not self.has_parts:
              raise AssertionError("parts are not present!")
          return func(self, *args, **kwargs)

      return inner
  class ImageWrapper(object):
      """Bundle an image with its label, parts and an optional feature.

      The image may be given as a file path (opened with PIL), a PIL image or
      a numpy array. PIL images are converted to ``mode`` when accessed via
      :attr:`im`; :attr:`im_array` caches an array version of the image.
      Transforming methods return a modified :meth:`copy` and leave ``self``
      unchanged.
      """

      @staticmethod
      def read_image(im_path, mode="RGB"):
          """Open *im_path* with PIL and return the image object.

          *mode* is accepted for API compatibility but not applied here; the
          conversion to the wrapper's mode happens in the ``im`` property.
          """
          return Image.open(im_path, mode="r")

      def __init__(self, im_path, label, parts=None, mode="RGB", part_rescale_size=None):
          """
          Args:
              im_path: path to an image file, a PIL image or a numpy array.
              label: stored as-is on the instance.
              parts: part annotations, forwarded to ``Parts``.
              mode: PIL mode the image is converted to on access.
              part_rescale_size: forwarded to ``Parts``.
          """
          self.mode = mode
          # the ``im`` setter opens the file (for paths), resets the cached
          # array and marks this instance as the owner of the image
          self.im = im_path
          self.label = label
          # NOTE(review): a Parts object is created even when ``parts`` is
          # None, so ``has_parts`` is True afterwards — confirm intended.
          self.parts = Parts(self.im, parts, part_rescale_size)
          self.parent = None
          self._feature = None

      def __del__(self):
          """Close the PIL image if this instance owns it and it holds a file."""
          # copies share their parent's image object and must not close it;
          # ``getattr`` also guards a partially initialised instance (e.g. the
          # file-existence check in the ``im`` setter failed during __init__)
          if not getattr(self, "_owns_im", True):
              return
          im = getattr(self, "_im", None)
          if isinstance(im, Image.Image) and getattr(im, "fp", None) is not None:
              im.close()

      @property
      def im_array(self):
          """The image as a numpy array, computed on first access and cached.

          PIL images are converted to ``self.mode`` first. A 2-D array is
          stacked into three channels when ``self.mode == "RGB"`` and returned
          as-is otherwise; 3-D and 4-D arrays are returned as-is.

          Raises:
              ValueError: for an array of unsupported rank, or when the wrapped
                  object is neither a PIL image nor a numpy array.
          """
          if self._im_array is None:
              im = self._im
              if isinstance(im, Image.Image):
                  self._im_array = utils.asarray(im.convert(self.mode))
              elif isinstance(im, np.ndarray):
                  if im.ndim == 2 and self.mode == "RGB":
                      self._im_array = np.stack((im,) * 3, axis=-1)
                  elif im.ndim in (2, 3, 4):
                      self._im_array = im
                  else:
                      raise ValueError(
                          "unsupported image array rank: {}".format(im.ndim))
              else:
                  raise ValueError(
                      "unsupported image type: {}".format(type(im).__name__))
          return self._im_array

      @property
      def im(self):
          """The wrapped image; a PIL image is converted to ``self.mode`` (and
          the converted image kept) when its mode differs."""
          if isinstance(self._im, Image.Image) and self._im.mode != self.mode:
              self._im = self._im.convert(self.mode)
          return self._im

      @im.setter
      def im(self, value):
          """Set the image from a file path, a PIL image or a numpy array.

          Raises:
              AssertionError: if *value* is a path to a non-existing file.
          """
          if isinstance(value, str):
              assert isfile(value), "Image \"{}\" does not exist!".format(value)
              self._im = ImageWrapper.read_image(value, mode=self.mode)
              self._im_path = value
          else:
              self._im = value
          # a new image invalidates the cached array; without this, copies
          # (``crop``, ``reveal_visible``, ...) would keep the parent's array
          self._im_array = None
          self._owns_im = True

      def as_tuple(self):
          """Return ``(im_array, parts, label)``."""
          return self.im_array, self.parts, self.label

      def copy(self):
          """Return a shallow copy whose ``parent`` is ``self``.

          The image object is shared with ``self`` (and not owned by the copy,
          so the copy never closes it); ``_feature`` and ``parts`` are
          deep-copied so they can be modified independently.
          """
          new = copy.copy(self)
          new.parent = self
          new._owns_im = False
          for attr_name in ("_feature", "parts"):
              setattr(new, attr_name, copy.deepcopy(getattr(self, attr_name)))
          return new

      @property
      def feature(self):
          """Feature attached to this image (``None`` until set)."""
          return self._feature

      @feature.setter
      def feature(self, im_feature):
          self._feature = im_feature

      def crop(self, x, y, w, h):
          """Return a copy cropped to the box at ``(x, y)`` of size ``w`` x ``h``.

          PIL images use ``Image.crop``; 2-D/3-D numpy arrays are sliced as
          ``[y:y+h, x:x+w]``. Parts (if present) are shifted by ``(-x, -y)``.
          """
          result = self.copy()
          im = self.im
          if isinstance(im, np.ndarray) and im.ndim in (2, 3):
              result.im = im[y:y + h, x:x + w]
          else:
              result.im = im.crop((x, y, x + w, y + h))
          if self.has_parts:
              result.parts.offset(-x, -y)
          return result

      @should_have_parts
      def hide_parts_outside_bb(self, x, y, w, h):
          """Return a copy keeping only visible parts inside the given box.

          A visible part is kept when ``x <= px <= x+w`` and ``y <= py <= y+h``
          (bounds inclusive); the selected indices go to ``Parts.select``.
          """
          idxs, (xs, ys) = self.visible_part_locs()
          f = np.logical_and
          mask = f(f(x <= xs, xs <= x + w), f(y <= ys, ys <= y + h))
          result = self.copy()
          result.parts.select(idxs[mask])
          return result

      def uniform_parts(self, ratio):
          """Not implemented.

          Raises:
              NotImplementedError: always.
          """
          # TODO: the disabled implementation was
          #   result = self.copy()
          #   result.parts = utils.uniform_parts(self.im, ratio=ratio)
          #   return result
          raise NotImplementedError("FIX ME!")

      @should_have_parts
      def select_parts(self, idxs):
          """Return a copy whose parts are restricted via ``Parts.select(idxs)``."""
          result = self.copy()
          result.parts.select(idxs)
          return result

      @should_have_parts
      def select_random_parts(self, rnd, n_parts):
          """Return a copy with ``n_parts`` visible parts chosen using *rnd*
          (via ``utils.random_idxs``)."""
          idxs, _ = self.visible_part_locs()
          rnd_idxs = utils.random_idxs(idxs, rnd=rnd, n_parts=n_parts)
          return self.select_parts(rnd_idxs)

      @should_have_parts
      def visible_crops(self, ratio):
          """Delegate to ``Parts.visible_crops`` on the current image."""
          return self.parts.visible_crops(self.im, ratio=ratio)

      @should_have_parts
      def visible_part_locs(self):
          """Delegate to ``Parts.visible_locs``; returns ``(idxs, (xs, ys))``."""
          return self.parts.visible_locs()

      @should_have_parts
      def reveal_visible(self, ratio):
          """Return a copy whose image is ``utils.reveal_parts`` applied at
          the visible part locations."""
          _, xy = self.visible_part_locs()
          result = self.copy()
          result.im = utils.reveal_parts(self.im, xy, ratio=ratio)
          return result

      @should_have_parts
      def part_crops(self, ratio):
          """Return a copy whose image is the visible crops indexed by the
          visible part indices."""
          crops = self.visible_crops(ratio)
          idxs, _ = self.visible_part_locs()
          result = self.copy()
          result.im = crops[idxs]
          return result

      @property
      def has_parts(self):
          """Whether a parts object is attached (``self.parts is not None``)."""
          return self.parts is not None