image.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161
  1. from imageio import imread
  2. from PIL import Image
  3. from os.path import isfile
  4. import copy
  5. import numpy as np
  6. from . import utils
  7. from .part import Parts
  def should_have_parts(func):
      """Decorate a method so it only runs when the instance has parts.

      The wrapped method checks ``self.has_parts`` before delegating to
      ``func``; if it is falsy, ``AssertionError("parts are not present!")``
      is raised. ``functools.wraps`` keeps the method's name and docstring.
      """
      import functools

      @functools.wraps(func)
      def inner(self, *args, **kwargs):
          # Explicit raise instead of ``assert`` so the guard survives
          # ``python -O``; the exception type is unchanged for callers.
          if not self.has_parts:
              raise AssertionError("parts are not present!")
          return func(self, *args, **kwargs)

      return inner
  class ImageWrapper(object):
      """Bundle an image with its label, part annotations and an optional feature.

      The image is either a PIL image (opened from a path given to the
      constructor or to the ``im`` setter) or a numpy array. Derived wrappers
      (crops, part selections, ...) are produced through :meth:`copy` and keep
      a reference to the wrapper they were derived from in ``parent``.
      """

      @staticmethod
      def read_image(im_path, mode="RGB"):
          """Open ``im_path`` with ``PIL.Image.open`` and return the image.

          ``mode`` is accepted for API compatibility but unused here: conversion
          to the wrapper's mode happens in the ``im`` property.
          """
          return Image.open(im_path, mode="r")

      def __init__(self, im_path, label, parts=None, mode="RGB", part_rescale_size=None):
          # ``mode`` has to exist before ``im`` is assigned: the setter reads it.
          self.mode = mode
          self.im = im_path
          self._im_array = None
          self.label = label
          self.parts = Parts(self.im, parts, part_rescale_size)
          self.parent = None
          self._feature = None

      def __del__(self):
          """Close the file-backed PIL image, unless it is shared with ``parent``.

          :meth:`copy` is shallow, so a copy shares ``_im`` with its parent until
          a new image is assigned; closing it from the copy would invalidate the
          parent's image. The parent (kept alive by ``self.parent``) closes it.
          """
          # ``_im`` may be missing if __init__ failed before the image was set.
          im = getattr(self, "_im", None)
          if not isinstance(im, Image.Image):
              return
          parent = getattr(self, "parent", None)
          if parent is not None and getattr(parent, "_im", None) is im:
              return
          if getattr(im, "fp", None) is not None:
              im.close()

      @property
      def im_array(self):
          """Return the image as a numpy array, computed once and cached.

          PIL images are converted to ``self.mode`` and passed through
          ``utils.asarray``. Numpy arrays are used as-is, except that a 2-D array
          is stacked into three channels when ``self.mode`` is "RGB".

          Raises:
              ValueError: if an array is not 2-, 3- or 4-dimensional, or the
                  wrapped image is neither a PIL image nor a numpy array.
          """
          if self._im_array is None:
              im = self._im
              if isinstance(im, Image.Image):
                  self._im_array = utils.asarray(im.convert(self.mode))
              elif isinstance(im, np.ndarray):
                  if im.ndim == 2 and self.mode == "RGB":
                      self._im_array = np.stack((im,) * 3, axis=-1)
                  elif im.ndim in (2, 3, 4):
                      # 2-D arrays are valid for single-channel modes (e.g. "L"),
                      # matching what the PIL branch produces for such modes.
                      self._im_array = im
                  else:
                      raise ValueError(
                          "unsupported image array with {} dimensions".format(im.ndim))
              else:
                  raise ValueError(
                      "unsupported image type: {}".format(type(im).__name__))
          return self._im_array

      @property
      def im(self):
          """The wrapped image; a PIL image is converted to ``self.mode`` on access."""
          if isinstance(self._im, Image.Image) and self._im.mode != self.mode:
              self._im = self._im.convert(self.mode)
          return self._im

      @im.setter
      def im(self, value):
          """Set the image from a file path or from an in-memory image/array.

          A ``str`` is treated as a path: it must point to an existing file
          (``AssertionError`` otherwise) and is opened via :meth:`read_image`.
          Any cached ``im_array`` is invalidated.
          """
          if isinstance(value, str):
              # Explicit raise (not ``assert``) so the check survives ``python -O``.
              if not isfile(value):
                  raise AssertionError("Image \"{}\" does not exist!".format(value))
              self._im = ImageWrapper.read_image(value, mode=self.mode)
              self._im_path = value
          else:
              self._im = value
          # A cached array belongs to the previous image (for copies: the
          # parent's image, inherited through the shallow copy), so drop it.
          self._im_array = None

      def as_tuple(self):
          """Return ``(im_array, parts, label)``."""
          return self.im_array, self.parts, self.label

      def copy(self):
          """Return a shallow copy of this wrapper whose ``parent`` is ``self``.

          ``_feature`` and ``parts`` are deep-copied so the copy can modify them
          without affecting this wrapper; the image object itself is shared.
          """
          new = copy.copy(self)
          new.parent = self
          for attr_name in ("_feature", "parts"):
              setattr(new, attr_name, copy.deepcopy(getattr(self, attr_name)))
          return new

      @property
      def feature(self):
          """Feature attached to this image (``None`` until assigned)."""
          return self._feature

      @feature.setter
      def feature(self, im_feature):
          self._feature = im_feature

      def crop(self, x, y, w, h):
          """Return a copy cropped to the box at ``(x, y)`` with width ``w`` and height ``h``.

          Works for PIL images and for numpy arrays (indexed as ``[row, col]``).
          Part coordinates in columns 1 and 2 of ``parts`` are shifted by ``x``
          and ``y`` so they stay relative to the cropped image.
          """
          result = self.copy()
          im = self.im
          if isinstance(im, np.ndarray):
              result.im = im[y:y + h, x:x + w]
          else:
              # PIL's Image.crop takes one (left, upper, right, lower) box tuple.
              result.im = im.crop((x, y, x + w, y + h))
          if self.has_parts:
              result.parts[:, 1] -= x
              result.parts[:, 2] -= y
          return result

      @should_have_parts
      def hide_parts_outside_bb(self, x, y, w, h):
          """Return a copy keeping only the visible parts inside the given box.

          A visible part is kept when ``x <= px <= x + w`` and
          ``y <= py <= y + h`` (bounds inclusive).
          """
          idxs, (xs, ys) = self.visible_part_locs()
          f = np.logical_and
          inside = f(f(x <= xs, xs <= x + w), f(y <= ys, ys <= y + h))
          # ``inside`` is aligned with the visible parts only; map it back to the
          # part indices from visible_part_locs, i.e. the same kind of indices
          # select_random_parts hands to ``parts.select``.
          idxs = np.asarray(idxs)[inside]
          result = self.copy()
          result.parts.select(idxs)
          return result

      def uniform_parts(self, ratio):
          """Not implemented yet: always raises ``NotImplementedError``.

          The intended behaviour is a copy whose parts come from
          ``utils.uniform_parts(self.im, ratio=ratio)``.
          """
          # TODO: confirm utils.uniform_parts returns something usable as
          # ``parts`` (a ``Parts`` object) before enabling this method.
          raise NotImplementedError("FIX ME!")

      @should_have_parts
      def select_parts(self, idxs):
          """Return a copy whose parts are restricted to ``idxs`` via ``parts.select``."""
          result = self.copy()
          result.parts.select(idxs)
          return result

      @should_have_parts
      def select_random_parts(self, rnd, n_parts):
          """Return a copy with ``n_parts`` visible parts chosen by ``utils.random_idxs`` using ``rnd``."""
          idxs, _ = self.visible_part_locs()
          rnd_idxs = utils.random_idxs(idxs, rnd=rnd, n_parts=n_parts)
          return self.select_parts(rnd_idxs)

      @should_have_parts
      def visible_crops(self, ratio):
          """Delegate to ``parts.visible_crops(ratio=ratio)``."""
          return self.parts.visible_crops(ratio=ratio)

      @should_have_parts
      def visible_part_locs(self):
          """Delegate to ``parts.visible_locs()``; returns ``(idxs, (xs, ys))``."""
          return self.parts.visible_locs()

      @should_have_parts
      def reveal_visible(self, ratio):
          """Return a copy whose image is ``utils.reveal_parts`` applied at the visible part locations."""
          _, xy = self.visible_part_locs()
          result = self.copy()
          result.im = utils.reveal_parts(self.im, xy, ratio=ratio)
          return result

      @should_have_parts
      def part_crops(self, ratio):
          """Return a copy whose image is replaced by the visible part crops.

          NOTE(review): ``crops`` already comes from ``visible_crops`` and is
          then indexed with the visible part indices; confirm both index spaces
          agree when some parts are hidden.
          """
          crops = self.visible_crops(ratio)
          idxs, _ = self.visible_part_locs()
          result = self.copy()
          result.im = crops[idxs]
          return result

      @property
      def has_parts(self):
          """True unless ``parts`` is None (``__init__`` always assigns a ``Parts`` object)."""
          return self.parts is not None