part.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198
  1. import numpy as np
  2. from matplotlib import pyplot as plt
  3. from contextlib import contextmanager
  4. from matplotlib.patches import Rectangle
  5. from abc import ABC, abstractmethod, abstractproperty
  6. from . import utils
  7. class Parts(object):
  8. def __init__(self, image, part_annotations, rescale_size):
  9. super(Parts, self).__init__()
  10. self._parts = [BasePart.new(image, a, rescale_size) for a in part_annotations]
  11. def __getitem__(self, i):
  12. return self._parts[i]
  13. @property
  14. def selected(self):
  15. return np.array([p.is_visible for p in self._parts], dtype=bool)
  16. @property
  17. def selected_idxs(self):
  18. return np.where(self.selected)[0]
  19. def select(self, idxs):
  20. if isinstance(idxs, np.ndarray) and idxs.dtype == bool:
  21. # a mask is present, so convert it to indeces
  22. idxs = np.where(idxs)[0]
  23. for p in self._parts:
  24. p.is_visible = p._id in idxs
  25. def invert_selection(self):
  26. self.select(np.logical_not(self.selected))
  27. def offset(self, dx, dy):
  28. for p in self._parts:
  29. p.x += dx
  30. p.y += dy
  31. def visible_locs(self):
  32. vis = [(p._id, p.xy) for p in self._parts if p.is_visible]
  33. idxs, xy = zip(*vis)
  34. return np.array(idxs), np.array(xy).T
  35. def visible_crops(self, *args, **kwargs):
  36. return np.array([p.crop(*args, **kwargs) for p in self._parts])
  37. def plot(self, cmap=plt.cm.jet, **kwargs):
  38. for i, p in enumerate(self._parts):
  39. p.plot(color=cmap(i/len(self._parts)), **kwargs)
  40. def reveal(self, im, ratio, *args, **kwargs):
  41. res = np.zeros_like(im)
  42. for part in self._parts:
  43. if not part.is_visible: continue
  44. x, y, crop = part.reveal(im, ratio=ratio, *args, **kwargs)
  45. h, w, _ = crop.shape
  46. res[y:y+h, x:x+w] = crop
  47. return res
  48. class BasePart(ABC):
  49. @staticmethod
  50. def new(image, annotation, rescale_size=-1):
  51. if len(annotation) == 4:
  52. return LocationPart(image, annotation, rescale_size)
  53. elif len(annotation) == 5:
  54. return BBoxPart(image, annotation, rescale_size)
  55. else:
  56. raise ValueError("Unknown part annotation format: {}".format(annotation))
  57. def rescale(self, image, annotation, rescale_size):
  58. if rescale_size is not None and rescale_size > 0:
  59. h, w, c = utils.dimensions(image)
  60. scale = np.array([w, h]) / rescale_size
  61. xy = annotation[1:3]
  62. xy = xy * scale
  63. annotation[1:3] = xy
  64. return annotation
  65. @property
  66. def is_visible(self):
  67. return self._is_visible
  68. @is_visible.setter
  69. def is_visible(self, value):
  70. self._is_visible = bool(value)
  71. @property
  72. def xy(self):
  73. return np.array([self.x, self.y])
  74. def crop(self, im, w, h, padding_mode="edge", is_location=True):
  75. x, y = self.xy
  76. pad_h, pad_w = h // 2, w // 2
  77. padded_im = np.pad(im, [(pad_h, pad_h), (pad_w, pad_w), [0,0]], mode=padding_mode)
  78. x0, y0 = x + pad_w, y + pad_h
  79. if is_location:
  80. x0, y0 = x0 - w // 2, y0 - h // 2
  81. return padded_im[y0:y0+h, x0:x0+w]
  82. def plot(self, **kwargs):
  83. return
  84. class LocationPart(BasePart):
  85. DEFAULT_RATIO = np.sqrt(49 / 400) # 0.35
  86. def __init__(self, image, annotation, rescale_size):
  87. super(LocationPart, self).__init__()
  88. annotation = self.rescale(image, annotation, rescale_size)
  89. # here x,y are the center of the part
  90. self._id, self.x, self.y, self.is_visible = annotation
  91. self._ratio = LocationPart.DEFAULT_RATIO
  92. def crop(self, image, ratio=None, padding_mode="edge", *args, **kwargs):
  93. ratio = ratio or self._ratio
  94. _h, _w, c = utils.dimensions(image)
  95. w, h = int(_w * ratio), int(_h * ratio)
  96. if not self.is_visible:
  97. return np.zeros((h, w, c), dtype=np.uint8)
  98. else:
  99. return super(LocationPart, self).crop(image, w, h,
  100. padding_mode, is_location=True)
  101. def reveal(self, im, ratio=None, *args, **kwargs):
  102. _h, _w, c = utils.dimensions(im)
  103. w, h = int(_w * ratio), int(_h * ratio)
  104. x,y = self.xy
  105. x, y = max(x - w // 2, 0), max(y - h // 2, 0)
  106. return x, y, im[y:y+h, x:x+w]
  107. def plot(self, im, ax, ratio, fill=False, linestyle="--", **kwargs):
  108. if not self.is_visible: return
  109. x, y = self.xy
  110. _h, _w, c = utils.dimensions(im)
  111. w, h = int(_w * ratio), int(_h * ratio)
  112. ax.add_patch(Rectangle(
  113. (x-w//2, y-h//2), w, h,
  114. fill=fill, linestyle=linestyle,
  115. **kwargs
  116. ))
  117. class BBoxPart(BasePart):
  118. def __init__(self, image, annotation, rescale_size):
  119. super(BBoxPart, self).__init__()
  120. annotation = self.rescale(image, annotation, rescale_size)
  121. # here x,y are top left corner of the part
  122. self._id, self.x, self.y, self.w, self.h = annotation
  123. self.is_visible = True
  124. def rescale(self, image, annotation, rescale_size):
  125. if rescale_size is not None and rescale_size > 0:
  126. annotation = super(BBoxPart, self).rescale(image, annotation, rescale_size)
  127. h, w, c = utils.dimensions(image)
  128. scale = np.array([w, h]) / rescale_size
  129. wh = annotation[3:5]
  130. wh = wh * scale
  131. annotation[3:5] = wh
  132. return annotation
  133. def crop(self, image, padding_mode="edge", *args, **kwargs):
  134. return super(BBoxPart, self).crop(image, self.w, self.h,
  135. padding_mode, is_location=False)
  136. def reveal(self, im, ratio, *args, **kwargs):
  137. _h, _w, c = utils.dimensions(im)
  138. x,y = self.xy
  139. return x, y, im[y:y+self.h, x:x+self.w]
  140. def plot(self, im, ax, ratio, fill=False, linestyle="--", **kwargs):
  141. ax.add_patch(Rectangle(
  142. (self.x, self.y), self.w, self.h,
  143. fill=fill, linestyle=linestyle,
  144. **kwargs
  145. ))