part.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131
  1. import numpy as np
  2. from contextlib import contextmanager
  3. from abc import ABC, abstractmethod, abstractproperty
  4. from . import utils
  5. class Parts(object):
  6. def __init__(self, image, part_annotations, rescale_size):
  7. super(Parts, self).__init__()
  8. annots = utils.rescale_parts(image, part_annotations, rescale_size)
  9. self._parts = [BasePart.new(a) for a in annots]
  10. self.rescale_size = rescale_size
  11. def __getitem__(self, i):
  12. return self._parts[i]
  13. @property
  14. def selected(self):
  15. return np.array([p.is_visible for p in self._parts], dtype=bool)
  16. @property
  17. def selected_idxs(self):
  18. return np.where(self.selected)[0]
  19. def select(self, idxs):
  20. if isinstance(idxs, np.ndarray) and idxs.dtype == bool:
  21. # a mask is present, so convert it to indeces
  22. idxs = np.where(idxs)[0]
  23. for p in self._parts:
  24. p.is_visible = p._id in idxs
  25. def invert_selection(self):
  26. self.select(np.logical_not(self.selected))
  27. def offset(self, dx, dy):
  28. for p in self._parts:
  29. p.x += dx
  30. p.y += dy
  31. def visible_locs(self):
  32. vis = [(p._id, p.xy) for p in self._parts if p.is_visible]
  33. idxs, xy = zip(*vis)
  34. return np.array(idxs), np.array(xy).T
  35. def visible_crops(self, *args, **kwargs):
  36. return np.array([p.crop(*args, **kwargs) for p in self._parts])
  37. class BasePart(ABC):
  38. def __init__(self, image, annotation):
  39. super(BasePart, self).__init__()
  40. self.image = image
  41. self.read_annotation(annotation)
  42. @staticmethod
  43. def new(image, annotation):
  44. if len(annotation) == 4:
  45. return LocationPart(image, annotation)
  46. elif len(annotation) == 5:
  47. return BBoxPart(image, annotation)
  48. else:
  49. raise ValueError("Unknown part annotation format: {}".format(annotation))
  50. @abstractmethod
  51. def read_annotation(self, annotation):
  52. raise NotImplementedError
  53. @property
  54. def is_visible(self):
  55. return self._is_visible
  56. @is_visible.setter
  57. def is_visible(self, value):
  58. self._is_visible = bool(value)
  59. @property
  60. def xy(self):
  61. return np.array([self.x, self.y])
  62. @property
  63. def c(self):
  64. h, w, c = utils.dimensions(self.image)
  65. return c
  66. @abstractmethod
  67. def crop(self, ratio=None, padding_mode="edge"):
  68. raise NotImplementedError
  69. class LocationPart(BasePart):
  70. def read_annotation(self, annotation):
  71. # here x,y are the center of the part
  72. self._id, self.x, self.y, self.is_visible = annotation
  73. self._ratio = None
  74. @abstractmethod
  75. def crop(self, padding_mode="edge", *args, **kwargs):
  76. raise NotImplementedError
  77. class LocationPart(BasePart):
  78. def read_annotation(self, annotation):
  79. # here x,y are the center of the part
  80. self._id, self.x, self.y, self.is_visible = annotation
  81. self._ratio = None
  82. def crop(self, image, ratio=None, padding_mode="edge", *args, **kwargs):
  83. ratio = ratio or self._ratio
  84. _h, _w, c = utils.dimensions(image)
  85. w, h = int(_w * ratio), int(_h * ratio)
  86. if not self.is_visible:
  87. return np.zeros((h, w, c), dtype=np.uint8)
  88. else:
  89. return utils.crop(image, self.xy, w, h,
  90. padding_mode, is_location=True)
  91. class BBoxPart(BasePart):
  92. def read_annotation(self, annotation):
  93. # here x,y are top left corner of the part
  94. self._id, self.x, self.y, self.w, self.h = annotation
  95. self._is_visible = True
  96. def crop(self, padding_mode="edge", *args, **kwargs):
  97. return utils.crop(self.image, self.xy, self.w, self.h,
  98. padding_mode, is_location=False)