base.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142
  1. import numpy as np
  2. from abc import ABC
  3. from abc import abstractproperty
  4. from matplotlib import pyplot as plt
  5. from skimage.transform import resize
  6. from cvdatasets import utils
  7. class BasePartCollection(ABC):
  8. def __getitem__(self, i):
  9. return self._parts[i]
  10. def __len__(self, i):
  11. return len(self._parts)
  12. def __repr__(self):
  13. return repr(np.stack([p.as_annotation for p in self._parts]))
  14. @property
  15. def selected(self):
  16. return np.array([p.is_visible for p in self._parts], dtype=bool)
  17. @property
  18. def selected_idxs(self):
  19. return np.where(self.selected)[0]
  20. def select(self, idxs):
  21. if isinstance(idxs, np.ndarray) and idxs.dtype == bool:
  22. # a mask is present, so convert it to indeces
  23. idxs = np.where(idxs)[0]
  24. for p in self._parts:
  25. p.is_visible = p._id in idxs
  26. def hide_outside_bb(self, *bounding_box):
  27. for p in self._parts:
  28. p.hide_if_outside(*bounding_box)
  29. def invert_selection(self):
  30. self.select(np.logical_not(self.selected))
  31. def offset(self, dx, dy):
  32. for p in self._parts:
  33. p.x += dx
  34. p.y += dy
  35. def visible_locs(self):
  36. vis = [(p._id, p.xy) for p in self._parts if p.is_visible]
  37. idxs, xy = zip(*vis)
  38. return np.array(idxs), np.array(xy).T
  39. def visible_crops(self, *args, **kwargs):
  40. crops = [p.crop(*args, **kwargs) for p in self._parts]
  41. return crops
  42. # return np.array(crops)
  43. def plot(self, cmap=plt.cm.jet, **kwargs):
  44. for i, p in enumerate(self._parts):
  45. p.plot(color=cmap(i/len(self._parts)), **kwargs)
  46. def reveal(self, im, ratio, *args, **kwargs):
  47. res = np.zeros_like(im)
  48. for part in self._parts:
  49. if not part.is_visible: continue
  50. x, y, crop = part.reveal(im, ratio=ratio, *args, **kwargs)
  51. h, w, _ = crop.shape
  52. res[y:y+h, x:x+w] = crop
  53. return res
  54. class BasePart(ABC):
  55. def __repr__(self):
  56. return repr(self.as_annotation)
  57. @staticmethod
  58. def new(image, annotation, rescale_size=-1, center_cropped=True):
  59. from .annotation import BBoxPart, LocationPart
  60. if len(annotation) == 4:
  61. return LocationPart(image, annotation, rescale_size)
  62. elif len(annotation) == 5:
  63. return BBoxPart(image, annotation, rescale_size, center_cropped)
  64. else:
  65. raise ValueError("Unknown part annotation format: {}".format(annotation))
  66. def rescale(self, image, annotation, rescale_size, center_cropped=True):
  67. if rescale_size is not None and rescale_size > 0:
  68. xy = annotation[1:3]
  69. new_xy = utils.rescale(image, xy, rescale_size, center_cropped)
  70. annotation[1:3] = new_xy
  71. return annotation
  72. @property
  73. def is_visible(self):
  74. return self._is_visible
  75. @is_visible.setter
  76. def is_visible(self, value):
  77. self._is_visible = bool(value)
  78. @property
  79. def xy(self):
  80. return np.array([self.x, self.y])
  81. def crop(self, im, w, h, padding_mode="edge", is_location=True):
  82. if not self.is_visible:
  83. _part_surrogate = resize(utils.asarray(im), (h, w),
  84. mode="constant",
  85. anti_aliasing=True,
  86. preserve_range=True)
  87. return _part_surrogate.astype(np.uint8)
  88. # old code using black images as surrogates
  89. _, _, c = utils.dimensions(im)
  90. return np.zeros((h, w, c), dtype=np.uint8)
  91. x, y = self.xy
  92. pad_h, pad_w = h // 2, w // 2
  93. padded_im = np.pad(im, [(pad_h, pad_h), (pad_w, pad_w), [0,0]], mode=padding_mode)
  94. x0, y0 = x + pad_w, y + pad_h
  95. if is_location:
  96. x0, y0 = x0 - w // 2, y0 - h // 2
  97. return padded_im[y0:y0+h, x0:x0+w]
  98. @abstractproperty
  99. def middle(self):
  100. raise NotImplementedError
  101. def plot(self, **kwargs):
  102. return
  103. def hide_if_outside(self, x, y, w, h):
  104. mid_x, mid_y = self.middle
  105. self.is_visible = ((x <= mid_x <= x+w) and (y <= mid_y <= y+h))