part.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201
  1. import numpy as np
  2. from matplotlib import pyplot as plt
  3. from contextlib import contextmanager
  4. from matplotlib.patches import Rectangle
  5. from abc import ABC, abstractmethod, abstractproperty
  6. from . import utils
  7. class Parts(object):
  8. def __init__(self, image, part_annotations, rescale_size):
  9. super(Parts, self).__init__()
  10. self._parts = [BasePart.new(image, a, rescale_size) for a in part_annotations]
  11. def __getitem__(self, i):
  12. return self._parts[i]
  13. @property
  14. def selected(self):
  15. return np.array([p.is_visible for p in self._parts], dtype=bool)
  16. @property
  17. def selected_idxs(self):
  18. return np.where(self.selected)[0]
  19. def select(self, idxs):
  20. if isinstance(idxs, np.ndarray) and idxs.dtype == bool:
  21. # a mask is present, so convert it to indeces
  22. idxs = np.where(idxs)[0]
  23. for p in self._parts:
  24. p.is_visible = p._id in idxs
  25. def invert_selection(self):
  26. self.select(np.logical_not(self.selected))
  27. def offset(self, dx, dy):
  28. for p in self._parts:
  29. p.x += dx
  30. p.y += dy
  31. def visible_locs(self):
  32. vis = [(p._id, p.xy) for p in self._parts if p.is_visible]
  33. idxs, xy = zip(*vis)
  34. return np.array(idxs), np.array(xy).T
  35. def visible_crops(self, *args, **kwargs):
  36. return np.array([p.crop(*args, **kwargs) for p in self._parts])
  37. def plot(self, cmap=plt.cm.jet, **kwargs):
  38. for i, p in enumerate(self._parts):
  39. p.plot(color=cmap(i/len(self._parts)), **kwargs)
  40. def reveal(self, im, ratio, *args, **kwargs):
  41. res = np.zeros_like(im)
  42. for part in self._parts:
  43. if not part.is_visible: continue
  44. x, y, crop = part.reveal(im, ratio=ratio, *args, **kwargs)
  45. h, w, _ = crop.shape
  46. res[y:y+h, x:x+w] = crop
  47. return res
  48. class BasePart(ABC):
  49. @staticmethod
  50. def new(image, annotation, rescale_size=-1):
  51. if len(annotation) == 4:
  52. return LocationPart(image, annotation, rescale_size)
  53. elif len(annotation) == 5:
  54. return BBoxPart(image, annotation, rescale_size)
  55. else:
  56. raise ValueError("Unknown part annotation format: {}".format(annotation))
  57. @abstractmethod
  58. def __init__(self, *args, **kwargs):
  59. raise NotImplementedError
  60. def rescale(self, image, annotation, rescale_size):
  61. if rescale_size is not None and rescale_size > 0:
  62. h, w, c = utils.dimensions(image)
  63. scale = np.array([w, h]) / rescale_size
  64. xy = annotation[1:3]
  65. xy = xy * scale
  66. annotation[1:3] = xy
  67. return annotation
  68. @property
  69. def is_visible(self):
  70. return self._is_visible
  71. @is_visible.setter
  72. def is_visible(self, value):
  73. self._is_visible = bool(value)
  74. @property
  75. def xy(self):
  76. return np.array([self.x, self.y])
  77. def crop(self, im, w, h, padding_mode="edge", is_location=True):
  78. x, y = self.xy
  79. pad_h, pad_w = h // 2, w // 2
  80. padded_im = np.pad(im, [(pad_h, pad_h), (pad_w, pad_w), [0,0]], mode=padding_mode)
  81. x0, y0 = x + pad_w, y + pad_h
  82. if is_location:
  83. x0, y0 = x0 - w // 2, y0 - h // 2
  84. return padded_im[y0:y0+h, x0:x0+w]
  85. def plot(self, **kwargs):
  86. return
  87. class LocationPart(BasePart):
  88. DEFAULT_RATIO = np.sqrt(49 / 400) # 0.35
  89. def __init__(self, image, annotation, rescale_size):
  90. annotation = self.rescale(image, annotation, rescale_size)
  91. # here x,y are the center of the part
  92. self._id, self.x, self.y, self.is_visible = annotation
  93. self._ratio = LocationPart.DEFAULT_RATIO
  94. def crop(self, image, ratio=None, padding_mode="edge", *args, **kwargs):
  95. ratio = ratio or self._ratio
  96. _h, _w, c = utils.dimensions(image)
  97. w, h = int(_w * ratio), int(_h * ratio)
  98. if not self.is_visible:
  99. return np.zeros((h, w, c), dtype=np.uint8)
  100. else:
  101. return super(LocationPart, self).crop(image, w, h,
  102. padding_mode, is_location=True)
  103. def reveal(self, im, ratio=None, *args, **kwargs):
  104. _h, _w, c = utils.dimensions(im)
  105. w, h = int(_w * ratio), int(_h * ratio)
  106. x,y = self.xy
  107. x, y = max(x - w // 2, 0), max(y - h // 2, 0)
  108. return x, y, im[y:y+h, x:x+w]
  109. def plot(self, im, ax, ratio, fill=False, linestyle="--", **kwargs):
  110. if not self.is_visible: return
  111. x, y = self.xy
  112. _h, _w, c = utils.dimensions(im)
  113. w, h = int(_w * ratio), int(_h * ratio)
  114. ax.add_patch(Rectangle(
  115. (x-w//2, y-h//2), w, h,
  116. fill=fill, linestyle=linestyle,
  117. **kwargs
  118. ))
  119. class BBoxPart(BasePart):
  120. def __init__(self, image, annotation, rescale_size):
  121. annotation = self.rescale(image, annotation, rescale_size)
  122. # here x,y are top left corner of the part
  123. self._id, self.x, self.y, self.w, self.h = annotation
  124. self.is_visible = True
  125. def rescale(self, image, annotation, rescale_size):
  126. if rescale_size is not None and rescale_size > 0:
  127. annotation = super(BBoxPart, self).rescale(image, annotation, rescale_size)
  128. h, w, c = utils.dimensions(image)
  129. scale = np.array([w, h]) / rescale_size
  130. wh = annotation[3:5]
  131. wh = wh * scale
  132. annotation[3:5] = wh
  133. return annotation
  134. def crop(self, image, padding_mode="edge", *args, **kwargs):
  135. return super(BBoxPart, self).crop(image, self.w, self.h,
  136. padding_mode, is_location=False)
  137. def reveal(self, im, ratio, *args, **kwargs):
  138. _h, _w, c = utils.dimensions(im)
  139. x,y = self.xy
  140. return x, y, im[y:y+self.h, x:x+self.w]
  141. def plot(self, im, ax, ratio, fill=False, linestyle="--", **kwargs):
  142. ax.add_patch(Rectangle(
  143. (self.x, self.y), self.w, self.h,
  144. fill=fill, linestyle=linestyle,
  145. **kwargs
  146. ))