part.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258
  1. import numpy as np
  2. from contextlib import contextmanager
  3. from matplotlib import pyplot as plt
  4. from matplotlib.patches import Rectangle
  5. from abc import ABC, abstractproperty
  6. from nabirds import utils
  7. class BaseParts(ABC):
  8. def __getitem__(self, i):
  9. return self._parts[i]
  10. def __repr__(self):
  11. return repr(np.stack([p.as_annotation for p in self._parts]))
  12. @property
  13. def selected(self):
  14. return np.array([p.is_visible for p in self._parts], dtype=bool)
  15. @property
  16. def selected_idxs(self):
  17. return np.where(self.selected)[0]
  18. def select(self, idxs):
  19. if isinstance(idxs, np.ndarray) and idxs.dtype == bool:
  20. # a mask is present, so convert it to indeces
  21. idxs = np.where(idxs)[0]
  22. for p in self._parts:
  23. p.is_visible = p._id in idxs
  24. def hide_outside_bb(self, *bounding_box):
  25. for p in self._parts:
  26. p.hide_if_outside(*bounding_box)
  27. def invert_selection(self):
  28. self.select(np.logical_not(self.selected))
  29. def offset(self, dx, dy):
  30. for p in self._parts:
  31. p.x += dx
  32. p.y += dy
  33. def visible_locs(self):
  34. vis = [(p._id, p.xy) for p in self._parts if p.is_visible]
  35. idxs, xy = zip(*vis)
  36. return np.array(idxs), np.array(xy).T
  37. def visible_crops(self, *args, **kwargs):
  38. return np.array([p.crop(*args, **kwargs) for p in self._parts])
  39. def plot(self, cmap=plt.cm.jet, **kwargs):
  40. for i, p in enumerate(self._parts):
  41. p.plot(color=cmap(i/len(self._parts)), **kwargs)
  42. def reveal(self, im, ratio, *args, **kwargs):
  43. res = np.zeros_like(im)
  44. for part in self._parts:
  45. if not part.is_visible: continue
  46. x, y, crop = part.reveal(im, ratio=ratio, *args, **kwargs)
  47. h, w, _ = crop.shape
  48. res[y:y+h, x:x+w] = crop
  49. return res
  50. class Parts(BaseParts):
  51. def __init__(self, image, part_annotations, rescale_size):
  52. super(Parts, self).__init__()
  53. self._parts = [BasePart.new(image, a, rescale_size) for a in part_annotations]
  54. class UniformParts(BaseParts):
  55. def __init__(self, image, ratio):
  56. super(UniformParts, self).__init__()
  57. self._parts = list(self.generate_parts(image, ratio))
  58. def generate_parts(self, im, ratio, round_op=np.floor):
  59. h, w, c = utils.dimensions(im)
  60. part_w = round_op(w * ratio).astype(np.int32)
  61. part_h = round_op(h * ratio).astype(np.int32)
  62. n, m = w // part_w, h // part_h
  63. # fit best possible part_w and part_h
  64. part_w = int(w / n)
  65. part_h = int(h / m)
  66. for i in range(n*m):
  67. row, col = np.unravel_index(i, (n, m))
  68. x, y = col * part_w, row * part_h
  69. yield BBoxPart(im, [i, x, y, part_w, part_h])
  70. class BasePart(ABC):
  71. def __repr__(self):
  72. return repr(self.as_annotation)
  73. @staticmethod
  74. def new(image, annotation, rescale_size=-1):
  75. if len(annotation) == 4:
  76. return LocationPart(image, annotation, rescale_size)
  77. elif len(annotation) == 5:
  78. return BBoxPart(image, annotation, rescale_size)
  79. else:
  80. raise ValueError("Unknown part annotation format: {}".format(annotation))
  81. def rescale(self, image, annotation, rescale_size):
  82. if rescale_size is not None and rescale_size > 0:
  83. h, w, c = utils.dimensions(image)
  84. scale = np.array([w, h]) / rescale_size
  85. xy = annotation[1:3]
  86. xy = xy * scale
  87. annotation[1:3] = xy
  88. return annotation
  89. @property
  90. def is_visible(self):
  91. return self._is_visible
  92. @is_visible.setter
  93. def is_visible(self, value):
  94. self._is_visible = bool(value)
  95. @property
  96. def xy(self):
  97. return np.array([self.x, self.y])
  98. def crop(self, im, w, h, padding_mode="edge", is_location=True):
  99. if not self.is_visible:
  100. _, _, c = utils.dimensions(im)
  101. return np.zeros((h, w, c), dtype=np.uint8)
  102. x, y = self.xy
  103. pad_h, pad_w = h // 2, w // 2
  104. padded_im = np.pad(im, [(pad_h, pad_h), (pad_w, pad_w), [0,0]], mode=padding_mode)
  105. x0, y0 = x + pad_w, y + pad_h
  106. if is_location:
  107. x0, y0 = x0 - w // 2, y0 - h // 2
  108. return padded_im[y0:y0+h, x0:x0+w]
  109. @abstractproperty
  110. def middle(self):
  111. raise NotImplementedError
  112. def plot(self, **kwargs):
  113. return
  114. def hide_if_outside(self, x, y, w, h):
  115. mid_x, mid_y = self.middle
  116. self.is_visible = ((x <= mid_x <= x+w) and (y <= mid_y <= y+h))
  117. class LocationPart(BasePart):
  118. DEFAULT_RATIO = np.sqrt(49 / 400) # 0.35
  119. def __init__(self, image, annotation, rescale_size=None):
  120. super(LocationPart, self).__init__()
  121. annotation = self.rescale(image, annotation, rescale_size)
  122. # here x,y are the center of the part
  123. self._id, self.x, self.y, self.is_visible = annotation
  124. self._ratio = LocationPart.DEFAULT_RATIO
  125. def as_annotation(self):
  126. return np.array([self._id, self.x, self.y, self.is_visible])
  127. def crop(self, image, ratio=None, padding_mode="edge", *args, **kwargs):
  128. ratio = ratio or self._ratio
  129. _h, _w, c = utils.dimensions(image)
  130. w, h = int(_w * ratio), int(_h * ratio)
  131. return super(LocationPart, self).crop(image, w, h,
  132. padding_mode, is_location=True)
  133. def reveal(self, im, ratio=None, *args, **kwargs):
  134. _h, _w, c = utils.dimensions(im)
  135. w, h = int(_w * ratio), int(_h * ratio)
  136. x,y = self.xy
  137. x, y = max(x - w // 2, 0), max(y - h // 2, 0)
  138. return x, y, im[y:y+h, x:x+w]
  139. def plot(self, im, ax, ratio, fill=False, linestyle="--", **kwargs):
  140. if not self.is_visible: return
  141. x, y = self.xy
  142. _h, _w, c = utils.dimensions(im)
  143. w, h = int(_w * ratio), int(_h * ratio)
  144. ax.add_patch(Rectangle(
  145. (x-w//2, y-h//2), w, h,
  146. fill=fill, linestyle=linestyle,
  147. **kwargs
  148. ))
  149. ax.scatter(*self.middle, marker="x", color="white", alpha=0.8)
  150. @property
  151. def middle(self):
  152. return np.array([self.x, self.y])
  153. class BBoxPart(BasePart):
  154. def __init__(self, image, annotation, rescale_size=None):
  155. super(BBoxPart, self).__init__()
  156. annotation = self.rescale(image, annotation, rescale_size)
  157. # here x,y are top left corner of the part
  158. self._id, self.x, self.y, self.w, self.h = annotation
  159. self.is_visible = True
  160. @property
  161. def as_annotation(self):
  162. return np.array([self._id, self.x, self.y, self.w, self.h])
  163. def rescale(self, image, annotation, rescale_size):
  164. if rescale_size is not None and rescale_size > 0:
  165. annotation = super(BBoxPart, self).rescale(image, annotation, rescale_size)
  166. h, w, c = utils.dimensions(image)
  167. scale = np.array([w, h]) / rescale_size
  168. wh = annotation[3:5]
  169. wh = wh * scale
  170. annotation[3:5] = wh
  171. return annotation
  172. @property
  173. def middle(self):
  174. return np.array([self.x + self.w // 2, self.y + self.h // 2])
  175. def crop(self, image, padding_mode="edge", *args, **kwargs):
  176. return super(BBoxPart, self).crop(image, self.w, self.h,
  177. padding_mode, is_location=False)
  178. def reveal(self, im, ratio, *args, **kwargs):
  179. _h, _w, c = utils.dimensions(im)
  180. x,y = self.xy
  181. return x, y, im[y:y+self.h, x:x+self.w]
  182. def plot(self, im, ax, ratio, fill=False, linestyle="--", **kwargs):
  183. ax.add_patch(Rectangle(
  184. (self.x, self.y), self.w, self.h,
  185. fill=fill, linestyle=linestyle,
  186. **kwargs
  187. ))
  188. ax.scatter(*self.middle, marker="x", color="white", alpha=0.8)