import numpy as np

from abc import ABC, abstractmethod
from imageio import imread

from . import utils


class BaseMixin(ABC):
    """Abstract root of the dataset mixin chain.

    Every mixin overrides ``get_example`` and calls ``super().get_example``
    to obtain the example produced further down the MRO, then post-processes
    it. Indexing (``ds[i]``) is forwarded to ``get_example``.
    """

    @abstractmethod
    def get_example(self, i):
        pass

    def __getitem__(self, i):
        return self.get_example(i)


class ImageReadMixin(BaseMixin):
    """Terminal mixin: loads ``(image, parts, label)`` for example ``i``.

    It must be the last mixin before ``BaseMixin`` in the MRO, because it is
    the one that actually produces the example tuple; the other mixins only
    transform it.

    Args:
        uuids: sequence of example identifiers; ``len(self)`` is its length.
        annotations: object whose ``image``, ``parts`` and ``label`` methods
            (and ``bounding_box``, used by ``BBCropMixin``) each take a uuid.
        mode: forwarded to ``imread`` as ``pilmode`` (default ``"RGB"``).
    """

    def __init__(self, uuids, annotations, mode="RGB"):
        # End of the cooperative __init__ chain: no further arguments to forward.
        super().__init__()
        self.uuids = uuids
        self._annot = annotations
        self.mode = mode

    def __len__(self):
        return len(self.uuids)

    def _get(self, method, i):
        """Return ``annotations.<method>(uuids[i])``."""
        return getattr(self._annot, method)(self.uuids[i])

    def get_example(self, i):
        res = super().get_example(i)
        # BaseMixin.get_example returns None; anything else means some other
        # class sits between this mixin and BaseMixin in the MRO.
        assert res is None, "ImageReadMixin should be the last class in the hierarchy!"

        methods = ["image", "parts", "label"]
        im_path, parts, label = [self._get(m, i) for m in methods]
        im = imread(im_path, pilmode=self.mode)

        # Mixins higher in the chain (BBCropMixin, RandomBlackOutMixin) modify
        # ``parts`` in place. The annotation object may return its own stored
        # array, so hand the chain a private copy; otherwise those in-place
        # edits would leak back into the annotations and accumulate over
        # repeated reads of the same example.
        if isinstance(parts, np.ndarray):
            parts = parts.copy()

        return im, parts, label


class BBCropMixin(BaseMixin):
    """Optionally crops each image to its annotated bounding box.

    Args:
        crop_to_bb: if True, crop the image to ``bounding_box(i)`` and shift
            the part coordinates into the crop's frame.
        crop_uniform: if True, ``bounding_box`` returns a square of side
            ``2 * max(w // 2, h // 2)`` centred on the annotated box, with its
            top-left corner clamped at 0.
    """

    def __init__(self, crop_to_bb=False, crop_uniform=False, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.crop_to_bb = crop_to_bb
        self.crop_uniform = crop_uniform

    def bounding_box(self, i):
        """Return the ``(x, y, w, h)`` crop box for example ``i``."""
        bbox = self._get("bounding_box", i)
        x, y, w, h = [bbox[attr] for attr in "xywh"]

        if self.crop_uniform:
            x0 = x + w // 2
            y0 = y + h // 2
            crop_size = max(w // 2, h // 2)
            # Clamp at 0: a negative slice start would index from the end.
            x, y = max(x0 - crop_size, 0), max(y0 - crop_size, 0)
            w = h = crop_size * 2

        return x, y, w, h

    def get_example(self, i):
        im, parts, label = super().get_example(i)

        if self.crop_to_bb:
            x, y, w, h = self.bounding_box(i)
            im = im[y:y + h, x:x + w]
            # Assumes columns 1 and 2 of ``parts`` are the x / y pixel
            # coordinates — TODO confirm against the annotation format.
            parts[:, 1] -= x
            parts[:, 2] -= y

        return im, parts, label


class UniformPartMixin(BaseMixin):
    """Optionally replaces the annotated parts with ``utils.uniform_parts``.

    Args:
        uniform_parts: if True, parts are recomputed from the (possibly
            cropped) image via ``utils.uniform_parts(im, ratio=ratio)``.
        ratio: forwarded to ``utils.uniform_parts``
            (default ``utils.DEFAULT_RATIO``).
    """

    def __init__(self, uniform_parts=False, ratio=utils.DEFAULT_RATIO, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.uniform_parts = uniform_parts
        self.ratio = ratio

    def get_example(self, i):
        im, parts, label = super().get_example(i)

        if self.uniform_parts:
            parts = utils.uniform_parts(im, ratio=self.ratio)

        return im, parts, label


class RandomBlackOutMixin(BaseMixin):
    """Optionally keeps only a random subset of the visible parts.

    Args:
        seed: seed for the mixin's private ``np.random.RandomState``.
        rnd_select: if True, the last column of ``parts`` is reset to 0 and
            set to 1 only for the indices chosen by ``utils.random_idxs``.
        n_parts: forwarded to ``utils.random_idxs``.
    """

    def __init__(self, seed=None, rnd_select=False, n_parts=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.rnd = np.random.RandomState(seed)
        self.rnd_select = rnd_select
        self.n_parts = n_parts

    def get_example(self, i):
        im, parts, lab = super().get_example(i)

        if self.rnd_select:
            idxs, _xy = utils.visible_part_locs(parts)
            rnd_idxs = utils.random_idxs(idxs, rnd=self.rnd, n_parts=self.n_parts)
            # Presumably the last column is a visibility flag — TODO confirm.
            parts[:, -1] = 0
            parts[rnd_idxs, -1] = 1

        return im, parts, lab


class Dataset(RandomBlackOutMixin, UniformPartMixin, BBCropMixin, ImageReadMixin):
    """Concrete dataset composed of all mixins.

    Given the MRO, ``get_example`` reads the example (ImageReadMixin), then
    applies, in order: bounding-box cropping, uniform part replacement and
    random part selection.
    """