|
@@ -0,0 +1,134 @@
|
|
|
+import numpy as np
|
|
|
+
|
|
|
+from .base import BaseMixin
|
|
|
+from . import utils
|
|
|
+
|
|
|
+
|
|
|
+class BasePartMixin(BaseMixin):
|
|
|
+
|
|
|
+ def get_example(self, i):
|
|
|
+ res = super(BasePartMixin, self).get_example(i)
|
|
|
+ if len(res) == 2:
|
|
|
+ # result has only image and label
|
|
|
+ im, lab = res
|
|
|
+ parts = None
|
|
|
+ else:
|
|
|
+ # result has already parts
|
|
|
+ im, parts, lab = res
|
|
|
+
|
|
|
+ return im, parts, lab
|
|
|
+
|
|
|
+class BBCropMixin(BasePartMixin):
|
|
|
+
|
|
|
+ def __init__(self, crop_to_bb=False, crop_uniform=False, *args, **kwargs):
|
|
|
+ super(BBCropMixin, self).__init__(*args, **kwargs)
|
|
|
+ self.crop_to_bb = crop_to_bb
|
|
|
+ self.crop_uniform = crop_uniform
|
|
|
+
|
|
|
+ def bounding_box(self, i):
|
|
|
+ bbox = self._get("bounding_box", i)
|
|
|
+ x,y,w,h = [bbox[attr] for attr in "xywh"]
|
|
|
+ if self.crop_uniform:
|
|
|
+ x0 = x + w//2
|
|
|
+ y0 = y + h//2
|
|
|
+
|
|
|
+ crop_size = max(w//2, h//2)
|
|
|
+
|
|
|
+ x,y = max(x0 - crop_size, 0), max(y0 - crop_size, 0)
|
|
|
+ w = h = crop_size * 2
|
|
|
+ return x,y,w,h
|
|
|
+
|
|
|
+ def get_example(self, i):
|
|
|
+ im, parts, label = super(BBCropMixin, self).get_example(i)
|
|
|
+ if self.crop_to_bb:
|
|
|
+ x,y,w,h = self.bounding_box(i)
|
|
|
+ im = im[y:y+h, x:x+w]
|
|
|
+ if parts is not None:
|
|
|
+ parts[:, 1] -= x
|
|
|
+ parts[:, 2] -= y
|
|
|
+ return im, parts, label
|
|
|
+
|
|
|
+class PartCropMixin(BasePartMixin):
|
|
|
+ def __init__(self, return_part_crops=False, *args, **kwargs):
|
|
|
+ super(PartCropMixin, self).__init__(*args, **kwargs)
|
|
|
+ self.return_part_crops = return_part_crops
|
|
|
+
|
|
|
+ def get_example(self, i):
|
|
|
+ im, parts, label = super(PartCropMixin, self).get_example(i)
|
|
|
+ assert hasattr(self, "ratio"), "\"ratio\" attribute is missing!"
|
|
|
+ if not self.return_part_crops or parts is None or not hasattr(self, "ratio"):
|
|
|
+ return im, label
|
|
|
+
|
|
|
+ crops = utils.visible_crops(im, parts)
|
|
|
+ idxs, _ = utils.visible_part_locs(parts)
|
|
|
+
|
|
|
+ return crops[idxs], label
|
|
|
+
|
|
|
+
|
|
|
+class PartRevealMixin(BasePartMixin):
|
|
|
+ def __init__(self, reveal_visible=False, *args, **kwargs):
|
|
|
+ super(PartRevealMixin, self).__init__(*args, **kwargs)
|
|
|
+ self.reveal_visible = reveal_visible
|
|
|
+
|
|
|
+ def get_example(self, i):
|
|
|
+ im, parts, label = super(PartRevealMixin, self).get_example(i)
|
|
|
+ assert hasattr(self, "ratio"), "\"ratio\" attribute is missing!"
|
|
|
+ if not self.reveal_visible or parts is None or not hasattr(self, "ratio"):
|
|
|
+ return im, label
|
|
|
+
|
|
|
+ _, xy = utils.visible_part_locs(parts)
|
|
|
+ im = utils.reveal_parts(im, xy, ratio=self.ratio)
|
|
|
+ return im, lab
|
|
|
+
|
|
|
+
|
|
|
+class UniformPartMixin(BasePartMixin):
|
|
|
+
|
|
|
+ def __init__(self, uniform_parts=False, ratio=utils.DEFAULT_RATIO, *args, **kwargs):
|
|
|
+ super(UniformPartMixin, self).__init__(*args, **kwargs)
|
|
|
+ self.uniform_parts = uniform_parts
|
|
|
+ self.ratio = ratio
|
|
|
+
|
|
|
+ def get_example(self, i):
|
|
|
+ im, parts, label = super(UniformPartMixin, self).get_example(i)
|
|
|
+ if self.uniform_parts:
|
|
|
+ parts = utils.uniform_parts(im, ratio=self.ratio)
|
|
|
+ return im, parts, label
|
|
|
+
|
|
|
+class RandomBlackOutMixin(BasePartMixin):
|
|
|
+
|
|
|
+ def __init__(self, seed=None, rnd_select=False, n_parts=None, *args, **kwargs):
|
|
|
+ super(RandomBlackOutMixin, self).__init__(*args, **kwargs)
|
|
|
+ self.rnd = np.random.RandomState(seed)
|
|
|
+ self.rnd_select = rnd_select
|
|
|
+ self.n_parts = n_parts
|
|
|
+
|
|
|
+ def get_example(self, i):
|
|
|
+ im, parts, lab = super(RandomBlackOutMixin, self).get_example(i)
|
|
|
+ if self.rnd_select:
|
|
|
+ idxs, xy = utils.visible_part_locs(parts)
|
|
|
+ rnd_idxs = utils.random_idxs(idxs, rnd=self.rnd, n_parts=self.n_parts)
|
|
|
+
|
|
|
+ parts[:, -1] = 0
|
|
|
+ parts[rnd_idxs, -1] = 1
|
|
|
+
|
|
|
+ return im, parts, lab
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+# some shortcuts
|
|
|
+
|
|
|
+class PartMixin(RandomBlackOutMixin, UniformPartMixin, BBCropMixin):
|
|
|
+ """
|
|
|
+ TODO!
|
|
|
+ """
|
|
|
+
|
|
|
+class RevealedPartMixin(PartRevealMixin, PartMixin):
|
|
|
+ """
|
|
|
+ TODO!
|
|
|
+ """
|
|
|
+
|
|
|
+
|
|
|
+class CroppedPartMixin(PartCropMixin, PartMixin):
|
|
|
+ """
|
|
|
+ TODO!
|
|
|
+ """
|