|
@@ -1,12 +1,50 @@
|
|
-from imageio import imread
|
|
|
|
import numpy as np
|
|
import numpy as np
|
|
|
|
|
|
|
|
+from abc import ABC, abstractmethod
|
|
|
|
+from imageio import imread
|
|
|
|
+
|
|
|
|
+from . import utils
|
|
|
|
+
|
|
|
|
+class BaseMixin(ABC):
|
|
|
|
+
|
|
|
|
+ @abstractmethod
|
|
|
|
+ def get_example(self, i):
|
|
|
|
+ pass
|
|
|
|
+
|
|
|
|
+ def __getitem__(self, i):
|
|
|
|
+ return self.get_example(i)
|
|
|
|
|
|
-class Dataset(object):
|
|
|
|
- def __init__(self, uuids, annotations, crop_to_bb=False, crop_uniform=False):
|
|
|
|
- super(Dataset, self).__init__()
|
|
|
|
|
|
+
|
|
|
|
+class ImageReadMixin(BaseMixin):
|
|
|
|
+
|
|
|
|
+ def __init__(self, uuids, annotations, mode="RGB"):
|
|
|
|
+ super(BaseMixin, self).__init__()
|
|
self.uuids = uuids
|
|
self.uuids = uuids
|
|
self._annot = annotations
|
|
self._annot = annotations
|
|
|
|
+ self.mode = mode
|
|
|
|
+
|
|
|
|
+ def __len__(self):
|
|
|
|
+ return len(self.uuids)
|
|
|
|
+
|
|
|
|
+ def _get(self, method, i):
|
|
|
|
+ return getattr(self._annot, method)(self.uuids[i])
|
|
|
|
+
|
|
|
|
+ def get_example(self, i):
|
|
|
|
+ res = super(ImageReadMixin, self).get_example(i)
|
|
|
|
+ # if the super class returns something, then the class inheritance is wrong
|
|
|
|
+ assert res is None, "ImageReadMixin should be the last class in the hierarchy!"
|
|
|
|
+
|
|
|
|
+ methods = ["image", "parts", "label"]
|
|
|
|
+ im_path, parts, label = [self._get(m, i) for m in methods]
|
|
|
|
+
|
|
|
|
+ im = imread(im_path, pilmode=self.mode)
|
|
|
|
+
|
|
|
|
+ return im, parts, label
|
|
|
|
+
|
|
|
|
+class BBCropMixin(BaseMixin):
|
|
|
|
+
|
|
|
|
+ def __init__(self, crop_to_bb=False, crop_uniform=False, *args, **kwargs):
|
|
|
|
+ super(BBCropMixin, self).__init__(*args, **kwargs)
|
|
self.crop_to_bb = crop_to_bb
|
|
self.crop_to_bb = crop_to_bb
|
|
self.crop_uniform = crop_uniform
|
|
self.crop_uniform = crop_uniform
|
|
|
|
|
|
@@ -23,29 +61,46 @@ class Dataset(object):
|
|
w = h = crop_size * 2
|
|
w = h = crop_size * 2
|
|
return x,y,w,h
|
|
return x,y,w,h
|
|
|
|
|
|
- def __len__(self):
|
|
|
|
- return len(self.uuids)
|
|
|
|
-
|
|
|
|
- def _get(self, method, i):
|
|
|
|
- return getattr(self._annot, method)(self.uuids[i])
|
|
|
|
-
|
|
|
|
-
|
|
|
|
- def get_example(self, i, mode="RGB"):
|
|
|
|
- methods = ["image", "parts", "label"]
|
|
|
|
- im_path, parts, label = [self._get(m, i) for m in methods]
|
|
|
|
-
|
|
|
|
- im = imread(im_path, pilmode=mode)
|
|
|
|
-
|
|
|
|
|
|
+ def get_example(self, i):
|
|
|
|
+ im, parts, label = super(BBCropMixin, self).get_example(i)
|
|
if self.crop_to_bb:
|
|
if self.crop_to_bb:
|
|
x,y,w,h = self.bounding_box(i)
|
|
x,y,w,h = self.bounding_box(i)
|
|
im = im[y:y+h, x:x+w]
|
|
im = im[y:y+h, x:x+w]
|
|
parts[:, 1] -= x
|
|
parts[:, 1] -= x
|
|
parts[:, 2] -= y
|
|
parts[:, 2] -= y
|
|
|
|
+ return im, parts, label
|
|
|
|
+
|
|
|
|
+class UniformPartMixin(BaseMixin):
|
|
|
|
+
|
|
|
|
+ def __init__(self, uniform_parts=False, ratio=utils.DEFAULT_RATIO, *args, **kwargs):
|
|
|
|
+ super(UniformPartMixin, self).__init__(*args, **kwargs)
|
|
|
|
+ self.uniform_parts = uniform_parts
|
|
|
|
+ self.ratio = ratio
|
|
|
|
|
|
- h,w,c = im.shape
|
|
|
|
- # fit to the dimensions of the image
|
|
|
|
- parts[:, 1] = np.minimum(parts[:, 1], w - 1)
|
|
|
|
- parts[:, 2] = np.minimum(parts[:, 2], h - 1)
|
|
|
|
|
|
+ def get_example(self, i):
|
|
|
|
+ im, parts, label = super(UniformPartMixin, self).get_example(i)
|
|
|
|
+ if self.uniform_parts:
|
|
|
|
+ parts = utils.uniform_parts(im, ratio=self.ratio)
|
|
return im, parts, label
|
|
return im, parts, label
|
|
|
|
|
|
- __getitem__ = get_example
|
|
|
|
|
|
+class RandomBlackOutMixin(BaseMixin):
|
|
|
|
+
|
|
|
|
+ def __init__(self, seed=None, rnd_select=False, n_parts=None, *args, **kwargs):
|
|
|
|
+ super(RandomBlackOutMixin, self).__init__(*args, **kwargs)
|
|
|
|
+ self.rnd = np.random.RandomState(seed)
|
|
|
|
+ self.rnd_select = rnd_select
|
|
|
|
+ self.n_parts = n_parts
|
|
|
|
+
|
|
|
|
+ def get_example(self, i):
|
|
|
|
+ im, parts, lab = super(RandomBlackOutMixin, self).get_example(i)
|
|
|
|
+ if self.rnd_select:
|
|
|
|
+ idxs, xy = utils.visible_part_locs(parts)
|
|
|
|
+ rnd_idxs = utils.random_idxs(idxs, rnd=self.rnd, n_parts=self.n_parts)
|
|
|
|
+
|
|
|
|
+ parts[:, -1] = 0
|
|
|
|
+ parts[rnd_idxs, -1] = 1
|
|
|
|
+
|
|
|
|
+ return im, parts, lab
|
|
|
|
+
|
|
|
|
+class Dataset(RandomBlackOutMixin, UniformPartMixin, BBCropMixin, ImageReadMixin):
|
|
|
|
+ pass
|