parts.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
  1. import numpy as np
  2. from . import BaseMixin
  3. class BBCropMixin(BaseMixin):
  4. def __init__(self, crop_to_bb=False, crop_uniform=False, *args, **kwargs):
  5. super(BBCropMixin, self).__init__(*args, **kwargs)
  6. self.crop_to_bb = crop_to_bb
  7. self.crop_uniform = crop_uniform
  8. def bounding_box(self, i):
  9. bbox = self._get("bounding_box", i)
  10. x,y,w,h = [bbox[attr] for attr in "xywh"]
  11. if self.crop_uniform:
  12. x0 = x + w//2
  13. y0 = y + h//2
  14. crop_size = max(w//2, h//2)
  15. x,y = max(x0 - crop_size, 0), max(y0 - crop_size, 0)
  16. w = h = crop_size * 2
  17. return x,y,w,h
  18. def get_example(self, i):
  19. im_obj = super(BBCropMixin, self).get_example(i)
  20. if self.crop_to_bb:
  21. bb = self.bounding_box(i)
  22. return im_obj.crop(*bb)
  23. return im_obj
  24. class PartsInBBMixin(BaseMixin):
  25. def __init__(self, parts_in_bb=False, *args, **kwargs):
  26. super(PartsInBBMixin, self).__init__(*args, **kwargs)
  27. self.parts_in_bb = parts_in_bb
  28. def get_example(self, i):
  29. im_obj = super(PartsInBBMixin, self).get_example(i)
  30. if self.parts_in_bb:
  31. bb = self.bounding_box(i)
  32. return im_obj.hide_parts_outside_bb(*bb)
  33. return im_obj
  34. class PartCropMixin(BaseMixin):
  35. def __init__(self, return_part_crops=False, *args, **kwargs):
  36. super(PartCropMixin, self).__init__(*args, **kwargs)
  37. self.return_part_crops = return_part_crops
  38. def get_example(self, i):
  39. im_obj = super(PartCropMixin, self).get_example(i)
  40. if self.return_part_crops:
  41. return im_obj.part_crops(self.ratio)
  42. return im_obj
  43. class PartRevealMixin(BaseMixin):
  44. def __init__(self, reveal_visible=False, *args, **kwargs):
  45. super(PartRevealMixin, self).__init__(*args, **kwargs)
  46. self.reveal_visible = reveal_visible
  47. def get_example(self, i):
  48. im_obj = super(PartRevealMixin, self).get_example(i)
  49. assert hasattr(self, "ratio"), "\"ratio\" attribute is missing!"
  50. if not self.reveal_visible:
  51. return im_obj.reveal_visible(self.ratio)
  52. return im_obj
  53. class UniformPartMixin(BaseMixin):
  54. def __init__(self, uniform_parts=False, ratio=None, *args, **kwargs):
  55. super(UniformPartMixin, self).__init__(*args, **kwargs)
  56. self.uniform_parts = uniform_parts
  57. self.ratio = ratio
  58. def get_example(self, i):
  59. im_obj = super(UniformPartMixin, self).get_example(i)
  60. if self.uniform_parts:
  61. return im_obj.uniform_parts(self.ratio)
  62. return im_obj
  63. class RandomBlackOutMixin(BaseMixin):
  64. def __init__(self, seed=None, rnd_select=False, n_parts=None, *args, **kwargs):
  65. super(RandomBlackOutMixin, self).__init__(*args, **kwargs)
  66. self.rnd = np.random.RandomState(seed)
  67. self.rnd_select = rnd_select
  68. self.n_parts = n_parts
  69. def get_example(self, i):
  70. im_obj = super(RandomBlackOutMixin, self).get_example(i)
  71. if self.rnd_select:
  72. return im_obj.select_random_parts(rnd=self.rnd, n_parts=self.n_parts)
  73. return im_obj
  74. # some shortcuts
  75. class PartMixin(RandomBlackOutMixin, PartsInBBMixin, UniformPartMixin, BBCropMixin):
  76. """
  77. TODO!
  78. """
  79. class RevealedPartMixin(PartRevealMixin, PartMixin):
  80. """
  81. TODO!
  82. """
  83. class CroppedPartMixin(PartCropMixin, PartMixin):
  84. """
  85. TODO!
  86. """