parts.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146
  1. import numpy as np
  2. from cvdatasets.dataset.mixins.base import BaseMixin
  class BBoxMixin(BaseMixin):
      """Expose the annotated bounding box of a sample as ``[x, y, w, h]``."""

      def bounding_box(self, i):
          """Return the bounding box of sample *i*.

          Reads the ``"bounding_box"`` annotation through ``self._get`` and
          returns its ``x``, ``y``, ``w`` and ``h`` entries, in that order,
          as a list.
          """
          annotation = self._get("bounding_box", i)
          return [annotation["x"], annotation["y"], annotation["w"], annotation["h"]]
  class MultiBoxMixin(BaseMixin):
      """Expose every object box of a sample, projected onto selectable keys.

      Each raw box is taken from the ``"objects"`` entry of the ``"multi_box"``
      annotation and must provide ``x0``, ``x1``, ``y0`` and ``y1``. The
      following values are derived from these:

      - the aliases ``x`` and ``y``, equal to ``x0`` and ``y0``;
      - the extents ``w = x1 - x0`` and ``h = y1 - y0``.
      """

      # Every key that ``multi_box`` can return per box.
      _all_keys = [
          "x", "x0", "x1",
          "y", "y0", "y1",
          "w", "h",
      ]

      def multi_box(self, i, keys=("x0", "x1", "y0", "y1")):
          """Return the requested values of every object box of sample *i*.

          Args:
              i: Index of the sample.
              keys: Names of the values to return per box, in output order.
                  Each name must be one of ``_all_keys``. Any iterable of
                  strings is accepted. It is materialized once, so generators
                  work too. The default is a tuple so it cannot be mutated
                  between calls.

          Returns:
              A list with one ``[box[key] for key in keys]`` list per object box.

          Raises:
              AssertionError: If *keys* contains a name not in ``_all_keys``.
                  This stays an ``assert`` for compatibility with existing
                  callers. Note that it is skipped under ``python -O``.
          """
          # Materialize once. The keys are validated here and then re-read for
          # every box; a generator would be exhausted after validation and
          # silently produce empty rows.
          keys = list(keys)
          unknown = [key for key in keys if key not in self._all_keys]
          assert not unknown, \
              f"unknown keys found: {unknown}. Possible are: {self._all_keys}"

          boxes = [
              dict(
                  x=box["x0"], x0=box["x0"], x1=box["x1"],
                  y=box["y0"], y0=box["y0"], y1=box["y1"],
                  w=box["x1"] - box["x0"],
                  h=box["y1"] - box["y0"],
              )
              for box in self._get("multi_box", i)["objects"]
          ]
          return [[box[key] for key in keys] for box in boxes]
  class BBCropMixin(BBoxMixin):
      """Optionally crop every example to its (possibly squared) bounding box.

      Keyword-only options:
          crop_to_bb: If true, ``get_example`` returns
              ``example.crop(*self.bounding_box(i))``.
          crop_uniform: If true, ``bounding_box`` returns a square box instead
              of the annotated one (see ``bounding_box``).
      """

      def __init__(self, *, crop_to_bb=False, crop_uniform=False, **kwargs):
          super().__init__(**kwargs)
          self.crop_to_bb = crop_to_bb
          self.crop_uniform = crop_uniform

      def bounding_box(self, i):
          """Return ``(x, y, w, h)`` for sample *i*.

          Without ``crop_uniform`` this is the annotated box. With it, the box
          is a square of side ``2 * max(w // 2, h // 2)``, centred at
          ``(x + w // 2, y + h // 2)``. Its top-left corner is clamped to be
          non-negative.

          NOTE(review): with odd extents the square side can be one smaller
          than the original box (e.g. w=5, h=3 gives side 4). Confirm this
          is intended.
          """
          x, y, w, h = super().bounding_box(i)
          if not self.crop_uniform:
              return x, y, w, h

          half_side = max(w // 2, h // 2)
          center_x, center_y = x + w // 2, y + h // 2
          side = 2 * half_side
          left = max(center_x - half_side, 0)
          top = max(center_y - half_side, 0)
          return left, top, side, side

      def get_example(self, i):
          """Return example *i*, cropped to ``self.bounding_box(i)`` if enabled."""
          example = super().get_example(i)
          if not self.crop_to_bb:
              return example
          return example.crop(*self.bounding_box(i))
  class PartsInBBMixin(BBoxMixin):
      """Optionally hide parts that lie outside the sample's bounding box.

      When ``parts_in_bb`` is true, ``get_example`` calls
      ``example.hide_parts_outside_bb(*self.bounding_box(i))``.
      ``bounding_box`` is resolved through the MRO. When ``BBCropMixin`` is
      also mixed in, its (possibly squared) box is used.
      """

      def __init__(self, parts_in_bb=False, *args, **kwargs):
          super().__init__(*args, **kwargs)
          self.parts_in_bb = parts_in_bb

      def get_example(self, i):
          example = super().get_example(i)
          if not self.parts_in_bb:
              return example
          box = self.bounding_box(i)
          return example.hide_parts_outside_bb(*box)
  class PartCropMixin(BaseMixin):
      """Optionally replace each example with ``example.part_crops(self.ratio)``.

      This mixin does not set ``ratio`` itself. Another mixin must provide it;
      in this file, ``UniformPartMixin`` does.
      """

      def __init__(self, return_part_crops=False, *args, **kwargs):
          super().__init__(*args, **kwargs)
          self.return_part_crops = return_part_crops

      def get_example(self, i):
          """Return example *i*, or its part crops when ``return_part_crops`` is set.

          Raises:
              AssertionError: If part crops are requested but no ``ratio``
                  attribute exists.
          """
          im_obj = super().get_example(i)
          if not self.return_part_crops:
              return im_obj
          # Same guard as PartRevealMixin: fail with an explicit message
          # instead of a bare AttributeError, and only when ratio is needed.
          assert hasattr(self, "ratio"), "\"ratio\" attribute is missing!"
          return im_obj.part_crops(self.ratio)
  class PartRevealMixin(BaseMixin):
      """Optionally pass each example through ``example.reveal_visible(self.ratio)``.

      This mixin does not set ``ratio`` itself. Another mixin must provide it;
      in this file, ``UniformPartMixin`` does.
      """

      def __init__(self, reveal_visible=False, *args, **kwargs):
          super().__init__(*args, **kwargs)
          self.reveal_visible = reveal_visible

      def get_example(self, i):
          """Return example *i*, post-processed by ``reveal_visible`` if enabled.

          Raises:
              AssertionError: If revealing is enabled but no ``ratio``
                  attribute exists.
          """
          im_obj = super().get_example(i)
          if not self.reveal_visible:
              return im_obj
          # ``ratio`` is only used on this path. Checking it here (rather than
          # unconditionally) avoids rejecting instances that never reveal.
          assert hasattr(self, "ratio"), "\"ratio\" attribute is missing!"
          return im_obj.reveal_visible(self.ratio)
  class UniformPartMixin(BaseMixin):
      """Optionally apply ``example.uniform_parts(self.ratio)`` to every example.

      ``ratio`` is always stored on the instance, even when ``uniform_parts``
      is false. ``PartCropMixin`` and ``PartRevealMixin`` read it.
      """

      def __init__(self, uniform_parts=False, ratio=None, *args, **kwargs):
          super().__init__(*args, **kwargs)
          self.uniform_parts = uniform_parts
          self.ratio = ratio

      def get_example(self, i):
          example = super().get_example(i)
          return example.uniform_parts(self.ratio) if self.uniform_parts else example
  class RandomBlackOutMixin(BaseMixin):
      """Optionally keep a random selection of parts per example.

      Args:
          seed: Seed for the per-instance ``np.random.RandomState``. A fixed
              seed makes the sequence of draws from ``self.rnd`` reproducible.
          rnd_select: If true, ``get_example`` calls
              ``example.select_random_parts(rnd=self.rnd,
              n_parts=self.blackout_parts)``.
          blackout_parts: Forwarded as ``n_parts``. Its exact meaning is
              defined by ``select_random_parts`` (not visible here).
      """

      def __init__(self, seed=None, rnd_select=False, blackout_parts=None, *args, **kwargs):
          super().__init__(*args, **kwargs)
          self.rnd = np.random.RandomState(seed)
          self.rnd_select = rnd_select
          self.blackout_parts = blackout_parts

      def get_example(self, i):
          example = super().get_example(i)
          if not self.rnd_select:
              return example
          return example.select_random_parts(rnd=self.rnd, n_parts=self.blackout_parts)
  96. # some shortcuts
  class PartMixin(RandomBlackOutMixin, PartsInBBMixin, UniformPartMixin, BBCropMixin):
      """Convenience combination of the part-related mixins.

      Constructor options, forwarded along the MRO:
          ``seed``, ``rnd_select``, ``blackout_parts``: ``RandomBlackOutMixin``.
          ``parts_in_bb``: ``PartsInBBMixin``.
          ``uniform_parts``, ``ratio``: ``UniformPartMixin``.
          ``crop_to_bb``, ``crop_uniform``: ``BBCropMixin`` (keyword-only).

      Each ``get_example`` first calls ``super().get_example`` and then
      transforms the result. The enabled steps therefore run in reverse MRO
      order:

      1. bounding-box crop;
      2. uniform parts;
      3. hiding parts outside the box;
      4. random part selection.

      ``PartsInBBMixin`` uses ``BBCropMixin.bounding_box``, so it sees the
      squared box when ``crop_uniform`` is set.
      """
  class RevealedPartMixin(PartRevealMixin, PartMixin):
      """``PartMixin`` plus ``PartRevealMixin``.

      Accepts the additional ``reveal_visible`` option. Because
      ``PartRevealMixin`` comes first in the MRO, ``reveal_visible(self.ratio)``
      is applied last, after all ``PartMixin`` steps. ``ratio`` comes from
      ``UniformPartMixin``.
      """
  class CroppedPartMixin(PartCropMixin, PartMixin):
      """``PartMixin`` plus ``PartCropMixin``.

      Accepts the additional ``return_part_crops`` option. Because
      ``PartCropMixin`` comes first in the MRO, ``part_crops(self.ratio)`` is
      applied last, after all ``PartMixin`` steps. ``ratio`` comes from
      ``UniformPartMixin``.
      """