base.py 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132
from abc import ABC, abstractmethod, abstractproperty

import numpy as np
from matplotlib import pyplot as plt

from nabirds import utils
  5. class BasePartCollection(ABC):
  6. def __getitem__(self, i):
  7. return self._parts[i]
  8. def __repr__(self):
  9. return repr(np.stack([p.as_annotation for p in self._parts]))
  10. @property
  11. def selected(self):
  12. return np.array([p.is_visible for p in self._parts], dtype=bool)
  13. @property
  14. def selected_idxs(self):
  15. return np.where(self.selected)[0]
  16. def select(self, idxs):
  17. if isinstance(idxs, np.ndarray) and idxs.dtype == bool:
  18. # a mask is present, so convert it to indeces
  19. idxs = np.where(idxs)[0]
  20. for p in self._parts:
  21. p.is_visible = p._id in idxs
  22. def hide_outside_bb(self, *bounding_box):
  23. for p in self._parts:
  24. p.hide_if_outside(*bounding_box)
  25. def invert_selection(self):
  26. self.select(np.logical_not(self.selected))
  27. def offset(self, dx, dy):
  28. for p in self._parts:
  29. p.x += dx
  30. p.y += dy
  31. def visible_locs(self):
  32. vis = [(p._id, p.xy) for p in self._parts if p.is_visible]
  33. idxs, xy = zip(*vis)
  34. return np.array(idxs), np.array(xy).T
  35. def visible_crops(self, *args, **kwargs):
  36. return np.array([p.crop(*args, **kwargs) for p in self._parts])
  37. def plot(self, cmap=plt.cm.jet, **kwargs):
  38. for i, p in enumerate(self._parts):
  39. p.plot(color=cmap(i/len(self._parts)), **kwargs)
  40. def reveal(self, im, ratio, *args, **kwargs):
  41. res = np.zeros_like(im)
  42. for part in self._parts:
  43. if not part.is_visible: continue
  44. x, y, crop = part.reveal(im, ratio=ratio, *args, **kwargs)
  45. h, w, _ = crop.shape
  46. res[y:y+h, x:x+w] = crop
  47. return res
  48. class BasePart(ABC):
  49. def __repr__(self):
  50. return repr(self.as_annotation)
  51. @staticmethod
  52. def new(image, annotation, rescale_size=-1):
  53. from .annotation import BBoxPart, LocationPart
  54. if len(annotation) == 4:
  55. return LocationPart(image, annotation, rescale_size)
  56. elif len(annotation) == 5:
  57. return BBoxPart(image, annotation, rescale_size)
  58. else:
  59. raise ValueError("Unknown part annotation format: {}".format(annotation))
  60. def rescale(self, image, annotation, rescale_size):
  61. if rescale_size is not None and rescale_size > 0:
  62. h, w, c = utils.dimensions(image)
  63. scale = np.array([w, h]) / rescale_size
  64. xy = annotation[1:3]
  65. xy = xy * scale
  66. annotation[1:3] = xy
  67. return annotation
  68. @property
  69. def is_visible(self):
  70. return self._is_visible
  71. @is_visible.setter
  72. def is_visible(self, value):
  73. self._is_visible = bool(value)
  74. @property
  75. def xy(self):
  76. return np.array([self.x, self.y])
  77. def crop(self, im, w, h, padding_mode="edge", is_location=True):
  78. if not self.is_visible:
  79. _, _, c = utils.dimensions(im)
  80. return np.zeros((h, w, c), dtype=np.uint8)
  81. x, y = self.xy
  82. pad_h, pad_w = h // 2, w // 2
  83. padded_im = np.pad(im, [(pad_h, pad_h), (pad_w, pad_w), [0,0]], mode=padding_mode)
  84. x0, y0 = x + pad_w, y + pad_h
  85. if is_location:
  86. x0, y0 = x0 - w // 2, y0 - h // 2
  87. return padded_im[y0:y0+h, x0:x0+w]
  88. @abstractproperty
  89. def middle(self):
  90. raise NotImplementedError
  91. def plot(self, **kwargs):
  92. return
  93. def hide_if_outside(self, x, y, w, h):
  94. mid_x, mid_y = self.middle
  95. self.is_visible = ((x <= mid_x <= x+w) and (y <= mid_y <= y+h))