part.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170
  1. import numpy as np
  2. from matplotlib import pyplot as plt
  3. from contextlib import contextmanager
  4. from matplotlib.patches import Rectangle
  5. from abc import ABC, abstractmethod, abstractproperty
  6. from . import utils
  7. class Parts(object):
  8. def __init__(self, image, part_annotations, rescale_size):
  9. super(Parts, self).__init__()
  10. annots = utils.rescale_parts(image, part_annotations, rescale_size)
  11. self._parts = [BasePart.new(a) for a in annots]
  12. self.rescale_size = rescale_size
  13. def __getitem__(self, i):
  14. return self._parts[i]
  15. @property
  16. def selected(self):
  17. return np.array([p.is_visible for p in self._parts], dtype=bool)
  18. @property
  19. def selected_idxs(self):
  20. return np.where(self.selected)[0]
  21. def select(self, idxs):
  22. if isinstance(idxs, np.ndarray) and idxs.dtype == bool:
  23. # a mask is present, so convert it to indeces
  24. idxs = np.where(idxs)[0]
  25. for p in self._parts:
  26. p.is_visible = p._id in idxs
  27. def invert_selection(self):
  28. self.select(np.logical_not(self.selected))
  29. def offset(self, dx, dy):
  30. for p in self._parts:
  31. p.x += dx
  32. p.y += dy
  33. def visible_locs(self):
  34. vis = [(p._id, p.xy) for p in self._parts if p.is_visible]
  35. idxs, xy = zip(*vis)
  36. return np.array(idxs), np.array(xy).T
  37. def visible_crops(self, *args, **kwargs):
  38. return np.array([p.crop(*args, **kwargs) for p in self._parts])
  39. def plot(self, cmap=plt.cm.jet, **kwargs):
  40. for i, p in enumerate(self._parts):
  41. p.plot(color=cmap(i/len(self._parts)), **kwargs)
  42. def reveal(self, im, ratio, *args, **kwargs):
  43. res = np.zeros_like(im)
  44. for part in self._parts:
  45. if not part.is_visible: continue
  46. x, y, crop = part.reveal(im, ratio=ratio, *args, **kwargs)
  47. h, w, _ = crop.shape
  48. res[y:y+h, x:x+w] = crop
  49. return res
  50. class BasePart(ABC):
  51. def __init__(self, annotation):
  52. super(BasePart, self).__init__()
  53. self.read_annotation(annotation)
  54. @staticmethod
  55. def new(annotation):
  56. if len(annotation) == 4:
  57. return LocationPart(annotation)
  58. elif len(annotation) == 5:
  59. return BBoxPart(annotation)
  60. else:
  61. raise ValueError("Unknown part annotation format: {}".format(annotation))
  62. @abstractmethod
  63. def read_annotation(self, annotation):
  64. raise NotImplementedError
  65. @property
  66. def is_visible(self):
  67. return self._is_visible
  68. @is_visible.setter
  69. def is_visible(self, value):
  70. self._is_visible = bool(value)
  71. @property
  72. def xy(self):
  73. return np.array([self.x, self.y])
  74. @abstractmethod
  75. def crop(self, *args, **kwargs):
  76. raise NotImplementedError
  77. def plot(self, **kwargs):
  78. return
  79. class LocationPart(BasePart):
  80. def read_annotation(self, annotation):
  81. # here x,y are the center of the part
  82. self._id, self.x, self.y, self.is_visible = annotation
  83. self._ratio = None
  84. def crop(self, image, ratio=None, padding_mode="edge", *args, **kwargs):
  85. ratio = ratio or self._ratio
  86. _h, _w, c = utils.dimensions(image)
  87. w, h = int(_w * ratio), int(_h * ratio)
  88. if not self.is_visible:
  89. return np.zeros((h, w, c), dtype=np.uint8)
  90. else:
  91. return utils.crop(image, self.xy, w, h,
  92. padding_mode, is_location=True)
  93. def reveal(self, im, ratio, *args, **kwargs):
  94. _h, _w, c = utils.dimensions(im)
  95. w, h = int(_w * ratio), int(_h * ratio)
  96. x,y = self.xy
  97. x, y = max(x - w // 2, 0), max(y - h // 2, 0)
  98. return x, y, im[y:y+h, x:x+w]
  99. def plot(self, im, ax, ratio, fill=False, linestyle="--", **kwargs):
  100. if not self.is_visible: return
  101. x, y = self.xy
  102. _h, _w, c = utils.dimensions(im)
  103. w, h = int(_w * ratio), int(_h * ratio)
  104. ax.add_patch(Rectangle(
  105. (x-w//2, y-h//2), w, h,
  106. fill=fill, linestyle=linestyle,
  107. **kwargs
  108. ))
  109. class BBoxPart(BasePart):
  110. def read_annotation(self, annotation):
  111. # here x,y are top left corner of the part
  112. self._id, self.x, self.y, self.w, self.h = annotation
  113. self._is_visible = True
  114. def crop(self, image, padding_mode="edge", *args, **kwargs):
  115. return utils.crop(image, self.xy, self.w, self.h,
  116. padding_mode, is_location=False)
  117. def reveal(self, im, ratio, *args, **kwargs):
  118. _h, _w, c = utils.dimensions(im)
  119. x,y = self.xy
  120. return x, y, im[y:y+self.h, x:x+self.w]
  121. def plot(self, im, ax, ratio, fill=False, linestyle="--", **kwargs):
  122. ax.add_patch(Rectangle(
  123. (self.x, self.y), self.w, self.h,
  124. fill=fill, linestyle=linestyle,
  125. **kwargs
  126. ))