parts.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134
  1. import numpy as np
  2. from .base import BaseMixin
  3. from . import utils
  4. class BasePartMixin(BaseMixin):
  5. def get_example(self, i):
  6. res = super(BasePartMixin, self).get_example(i)
  7. if len(res) == 2:
  8. # result has only image and label
  9. im, lab = res
  10. parts = None
  11. else:
  12. # result has already parts
  13. im, parts, lab = res
  14. return im, parts, lab
  15. class BBCropMixin(BasePartMixin):
  16. def __init__(self, crop_to_bb=False, crop_uniform=False, *args, **kwargs):
  17. super(BBCropMixin, self).__init__(*args, **kwargs)
  18. self.crop_to_bb = crop_to_bb
  19. self.crop_uniform = crop_uniform
  20. def bounding_box(self, i):
  21. bbox = self._get("bounding_box", i)
  22. x,y,w,h = [bbox[attr] for attr in "xywh"]
  23. if self.crop_uniform:
  24. x0 = x + w//2
  25. y0 = y + h//2
  26. crop_size = max(w//2, h//2)
  27. x,y = max(x0 - crop_size, 0), max(y0 - crop_size, 0)
  28. w = h = crop_size * 2
  29. return x,y,w,h
  30. def get_example(self, i):
  31. im, parts, label = super(BBCropMixin, self).get_example(i)
  32. if self.crop_to_bb:
  33. x,y,w,h = self.bounding_box(i)
  34. im = im[y:y+h, x:x+w]
  35. if parts is not None:
  36. parts[:, 1] -= x
  37. parts[:, 2] -= y
  38. return im, parts, label
  39. class PartCropMixin(BasePartMixin):
  40. def __init__(self, return_part_crops=False, *args, **kwargs):
  41. super(PartCropMixin, self).__init__(*args, **kwargs)
  42. self.return_part_crops = return_part_crops
  43. def get_example(self, i):
  44. im, parts, label = super(PartCropMixin, self).get_example(i)
  45. assert hasattr(self, "ratio"), "\"ratio\" attribute is missing!"
  46. if not self.return_part_crops or parts is None or not hasattr(self, "ratio"):
  47. return im, label
  48. crops = utils.visible_crops(im, parts)
  49. idxs, _ = utils.visible_part_locs(parts)
  50. return crops[idxs], label
  51. class PartRevealMixin(BasePartMixin):
  52. def __init__(self, reveal_visible=False, *args, **kwargs):
  53. super(PartRevealMixin, self).__init__(*args, **kwargs)
  54. self.reveal_visible = reveal_visible
  55. def get_example(self, i):
  56. im, parts, label = super(PartRevealMixin, self).get_example(i)
  57. assert hasattr(self, "ratio"), "\"ratio\" attribute is missing!"
  58. if not self.reveal_visible or parts is None or not hasattr(self, "ratio"):
  59. return im, label
  60. _, xy = utils.visible_part_locs(parts)
  61. im = utils.reveal_parts(im, xy, ratio=self.ratio)
  62. return im, lab
  63. class UniformPartMixin(BasePartMixin):
  64. def __init__(self, uniform_parts=False, ratio=utils.DEFAULT_RATIO, *args, **kwargs):
  65. super(UniformPartMixin, self).__init__(*args, **kwargs)
  66. self.uniform_parts = uniform_parts
  67. self.ratio = ratio
  68. def get_example(self, i):
  69. im, parts, label = super(UniformPartMixin, self).get_example(i)
  70. if self.uniform_parts:
  71. parts = utils.uniform_parts(im, ratio=self.ratio)
  72. return im, parts, label
  73. class RandomBlackOutMixin(BasePartMixin):
  74. def __init__(self, seed=None, rnd_select=False, n_parts=None, *args, **kwargs):
  75. super(RandomBlackOutMixin, self).__init__(*args, **kwargs)
  76. self.rnd = np.random.RandomState(seed)
  77. self.rnd_select = rnd_select
  78. self.n_parts = n_parts
  79. def get_example(self, i):
  80. im, parts, lab = super(RandomBlackOutMixin, self).get_example(i)
  81. if self.rnd_select:
  82. idxs, xy = utils.visible_part_locs(parts)
  83. rnd_idxs = utils.random_idxs(idxs, rnd=self.rnd, n_parts=self.n_parts)
  84. parts[:, -1] = 0
  85. parts[rnd_idxs, -1] = 1
  86. return im, parts, lab
  87. # some shortcuts
  88. class PartMixin(RandomBlackOutMixin, UniformPartMixin, BBCropMixin):
  89. """
  90. TODO!
  91. """
  92. class RevealedPartMixin(PartRevealMixin, PartMixin):
  93. """
  94. TODO!
  95. """
  96. class CroppedPartMixin(PartCropMixin, PartMixin):
  97. """
  98. TODO!
  99. """