part.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222
  1. import numpy as np
  2. from matplotlib import pyplot as plt
  3. from contextlib import contextmanager
  4. from matplotlib.patches import Rectangle
  5. from abc import ABC, abstractmethod, abstractproperty
  6. from . import utils
  7. class BaseParts(ABC):
  8. def __getitem__(self, i):
  9. return self._parts[i]
  10. @property
  11. def selected(self):
  12. return np.array([p.is_visible for p in self._parts], dtype=bool)
  13. @property
  14. def selected_idxs(self):
  15. return np.where(self.selected)[0]
  16. def select(self, idxs):
  17. if isinstance(idxs, np.ndarray) and idxs.dtype == bool:
  18. # a mask is present, so convert it to indeces
  19. idxs = np.where(idxs)[0]
  20. for p in self._parts:
  21. p.is_visible = p._id in idxs
  22. def invert_selection(self):
  23. self.select(np.logical_not(self.selected))
  24. def offset(self, dx, dy):
  25. for p in self._parts:
  26. p.x += dx
  27. p.y += dy
  28. def visible_locs(self):
  29. vis = [(p._id, p.xy) for p in self._parts if p.is_visible]
  30. idxs, xy = zip(*vis)
  31. return np.array(idxs), np.array(xy).T
  32. def visible_crops(self, *args, **kwargs):
  33. return np.array([p.crop(*args, **kwargs) for p in self._parts])
  34. def plot(self, cmap=plt.cm.jet, **kwargs):
  35. for i, p in enumerate(self._parts):
  36. p.plot(color=cmap(i/len(self._parts)), **kwargs)
  37. def reveal(self, im, ratio, *args, **kwargs):
  38. res = np.zeros_like(im)
  39. for part in self._parts:
  40. if not part.is_visible: continue
  41. x, y, crop = part.reveal(im, ratio=ratio, *args, **kwargs)
  42. h, w, _ = crop.shape
  43. res[y:y+h, x:x+w] = crop
  44. return res
  45. class Parts(BaseParts):
  46. def __init__(self, image, part_annotations, rescale_size):
  47. super(Parts, self).__init__()
  48. self._parts = [BasePart.new(image, a, rescale_size) for a in part_annotations]
  49. class UniformParts(BaseParts):
  50. def __init__(self, image, ratio):
  51. super(UniformParts, self).__init__()
  52. self._parts = list(self.generate_parts(image, ratio))
  53. def generate_parts(self, im, ratio, round_op=np.floor):
  54. h, w, c = utils.dimensions(im)
  55. part_w = round_op(w * ratio).astype(np.int32)
  56. part_h = round_op(h * ratio).astype(np.int32)
  57. n, m = w // part_w, h // part_h
  58. # fit best possible part_w and part_h
  59. part_w = int(w / n)
  60. part_h = int(h / m)
  61. for i in range(n*m):
  62. row, col = np.unravel_index(i, (n, m))
  63. x, y = col * part_w, row * part_h
  64. yield BBoxPart(im, [i, x, y, part_w, part_h])
  65. class BasePart(ABC):
  66. @staticmethod
  67. def new(image, annotation, rescale_size=-1):
  68. if len(annotation) == 4:
  69. return LocationPart(image, annotation, rescale_size)
  70. elif len(annotation) == 5:
  71. return BBoxPart(image, annotation, rescale_size)
  72. else:
  73. raise ValueError("Unknown part annotation format: {}".format(annotation))
  74. def rescale(self, image, annotation, rescale_size):
  75. if rescale_size is not None and rescale_size > 0:
  76. h, w, c = utils.dimensions(image)
  77. scale = np.array([w, h]) / rescale_size
  78. xy = annotation[1:3]
  79. xy = xy * scale
  80. annotation[1:3] = xy
  81. return annotation
  82. @property
  83. def is_visible(self):
  84. return self._is_visible
  85. @is_visible.setter
  86. def is_visible(self, value):
  87. self._is_visible = bool(value)
  88. @property
  89. def xy(self):
  90. return np.array([self.x, self.y])
  91. def crop(self, im, w, h, padding_mode="edge", is_location=True):
  92. if not self.is_visible:
  93. _, _, c = utils.dimensions(im)
  94. return np.zeros((h, w, c), dtype=np.uint8)
  95. x, y = self.xy
  96. pad_h, pad_w = h // 2, w // 2
  97. padded_im = np.pad(im, [(pad_h, pad_h), (pad_w, pad_w), [0,0]], mode=padding_mode)
  98. x0, y0 = x + pad_w, y + pad_h
  99. if is_location:
  100. x0, y0 = x0 - w // 2, y0 - h // 2
  101. return padded_im[y0:y0+h, x0:x0+w]
  102. def plot(self, **kwargs):
  103. return
  104. class LocationPart(BasePart):
  105. DEFAULT_RATIO = np.sqrt(49 / 400) # 0.35
  106. def __init__(self, image, annotation, rescale_size=None):
  107. super(LocationPart, self).__init__()
  108. annotation = self.rescale(image, annotation, rescale_size)
  109. # here x,y are the center of the part
  110. self._id, self.x, self.y, self.is_visible = annotation
  111. self._ratio = LocationPart.DEFAULT_RATIO
  112. def crop(self, image, ratio=None, padding_mode="edge", *args, **kwargs):
  113. ratio = ratio or self._ratio
  114. _h, _w, c = utils.dimensions(image)
  115. w, h = int(_w * ratio), int(_h * ratio)
  116. return super(LocationPart, self).crop(image, w, h,
  117. padding_mode, is_location=True)
  118. def reveal(self, im, ratio=None, *args, **kwargs):
  119. _h, _w, c = utils.dimensions(im)
  120. w, h = int(_w * ratio), int(_h * ratio)
  121. x,y = self.xy
  122. x, y = max(x - w // 2, 0), max(y - h // 2, 0)
  123. return x, y, im[y:y+h, x:x+w]
  124. def plot(self, im, ax, ratio, fill=False, linestyle="--", **kwargs):
  125. if not self.is_visible: return
  126. x, y = self.xy
  127. _h, _w, c = utils.dimensions(im)
  128. w, h = int(_w * ratio), int(_h * ratio)
  129. ax.add_patch(Rectangle(
  130. (x-w//2, y-h//2), w, h,
  131. fill=fill, linestyle=linestyle,
  132. **kwargs
  133. ))
  134. class BBoxPart(BasePart):
  135. def __init__(self, image, annotation, rescale_size=None):
  136. super(BBoxPart, self).__init__()
  137. annotation = self.rescale(image, annotation, rescale_size)
  138. # here x,y are top left corner of the part
  139. self._id, self.x, self.y, self.w, self.h = annotation
  140. self.is_visible = True
  141. def rescale(self, image, annotation, rescale_size):
  142. if rescale_size is not None and rescale_size > 0:
  143. annotation = super(BBoxPart, self).rescale(image, annotation, rescale_size)
  144. h, w, c = utils.dimensions(image)
  145. scale = np.array([w, h]) / rescale_size
  146. wh = annotation[3:5]
  147. wh = wh * scale
  148. annotation[3:5] = wh
  149. return annotation
  150. def crop(self, image, padding_mode="edge", *args, **kwargs):
  151. return super(BBoxPart, self).crop(image, self.w, self.h,
  152. padding_mode, is_location=False)
  153. def reveal(self, im, ratio, *args, **kwargs):
  154. _h, _w, c = utils.dimensions(im)
  155. x,y = self.xy
  156. return x, y, im[y:y+self.h, x:x+self.w]
  157. def plot(self, im, ax, ratio, fill=False, linestyle="--", **kwargs):
  158. ax.add_patch(Rectangle(
  159. (self.x, self.y), self.w, self.h,
  160. fill=fill, linestyle=linestyle,
  161. **kwargs
  162. ))