123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106 |
- import numpy as np
- from abc import ABC, abstractmethod
- from imageio import imread
- from . import utils
class BaseMixin(ABC):
    """Abstract root of the dataset mixin chain.

    Subclasses implement ``get_example``; indexing an instance
    (``dataset[i]``) is forwarded to that method.
    """

    @abstractmethod
    def get_example(self, i):
        """Return the example at index ``i``.

        The base implementation deliberately yields ``None``:
        ``ImageReadMixin.get_example`` calls it through ``super()`` and
        asserts on that ``None`` to verify its position in the hierarchy.
        """
        return None

    def __getitem__(self, i):
        """Support ``obj[i]`` by delegating to ``get_example``."""
        example = self.get_example(i)
        return example
class ImageReadMixin(BaseMixin):
    """Innermost mixin: loads the image and raw annotations of an example.

    Args:
        uuids: indexable, sized collection of example identifiers; its
            length is the dataset length.
        annotations: object whose methods (``image``, ``parts``,
            ``label``, ...) are called with a single uuid.
        mode: forwarded to ``imread`` as ``pilmode``.
    """

    def __init__(self, uuids, annotations, mode="RGB"):
        # Zero-argument super() starts the MRO lookup after *this* class.
        # The previous ``super(BaseMixin, self)`` started it after
        # BaseMixin, which would silently skip any ``__init__`` defined
        # on BaseMixin.
        super().__init__()
        self.uuids = uuids
        self._annot = annotations
        self.mode = mode

    def __len__(self):
        """Return the number of examples (one per uuid)."""
        return len(self.uuids)

    def _get(self, method, i):
        """Call ``self._annot.<method>`` with the uuid stored at index ``i``."""
        return getattr(self._annot, method)(self.uuids[i])

    def get_example(self, i):
        """Return ``(image, parts, label)`` for the example at index ``i``.

        The annotation object supplies the image path, the parts and the
        label; the image itself is read with ``imread``.

        Raises:
            AssertionError: if the next ``get_example`` in the MRO returns
                something other than ``None``.
        """
        res = super().get_example(i)
        # if the super class returns something, then the class inheritance is wrong
        assert res is None, "ImageReadMixin should be the last class in the hierarchy!"

        methods = ["image", "parts", "label"]
        im_path, parts, label = [self._get(m, i) for m in methods]
        im = imread(im_path, pilmode=self.mode)
        return im, parts, label
class BBCropMixin(BaseMixin):
    """Mixin that optionally crops each example to its bounding box.

    Args:
        crop_to_bb: if true, ``get_example`` crops the image to the box
            returned by ``bounding_box`` and shifts the part coordinates
            accordingly.
        crop_uniform: if true, ``bounding_box`` returns a square box
            built around the centre of the annotated one.
        *args, **kwargs: forwarded to the next ``__init__`` in the MRO.
    """

    def __init__(self, crop_to_bb=False, crop_uniform=False, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.crop_to_bb = crop_to_bb
        self.crop_uniform = crop_uniform

    def bounding_box(self, i):
        """Return the crop box ``(x, y, w, h)`` for the example at index ``i``.

        The annotated box is fetched with ``self._get("bounding_box", i)``;
        ``_get`` is not defined in this class and must be provided by
        another class in the hierarchy (``ImageReadMixin`` defines one).
        """
        bbox = self._get("bounding_box", i)
        x, y, w, h = [bbox[attr] for attr in "xywh"]
        if self.crop_uniform:
            # Square box with side 2 * max(w // 2, h // 2) around the
            # centre; the top-left corner is clamped at 0 while the side
            # length is left unchanged.
            x0 = x + w // 2
            y0 = y + h // 2
            crop_size = max(w // 2, h // 2)
            x, y = max(x0 - crop_size, 0), max(y0 - crop_size, 0)
            w = h = crop_size * 2
        return x, y, w, h

    def get_example(self, i):
        """Return ``(im, parts, label)``, cropped when ``crop_to_bb`` is set.

        When cropping, columns 1 and 2 of ``parts`` are shifted by the
        box origin so they stay aligned with the cropped image.
        """
        im, parts, label = super().get_example(i)
        if self.crop_to_bb:
            x, y, w, h = self.bounding_box(i)
            im = im[y:y + h, x:x + w]
            # Work on a copy: the array handed up by the parent may alias
            # data stored on the annotation object, and an in-place shift
            # would then accumulate on every access of the same example.
            # NOTE(review): assumes ``parts`` is array-like with ``.copy()``
            # (e.g. a numpy array) -- confirm against the annotation class.
            parts = parts.copy()
            parts[:, 1] -= x
            parts[:, 2] -= y
        return im, parts, label
class UniformPartMixin(BaseMixin):
    """Mixin that can replace the annotated parts with uniform ones.

    Args:
        uniform_parts: if true, ``get_example`` substitutes the parts by
            the result of ``utils.uniform_parts`` for the image.
        ratio: passed to ``utils.uniform_parts`` as ``ratio``.
        *args, **kwargs: forwarded to the next ``__init__`` in the MRO.
    """

    def __init__(self, uniform_parts=False, ratio=utils.DEFAULT_RATIO, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.ratio = ratio
        self.uniform_parts = uniform_parts

    def get_example(self, i):
        """Return ``(im, parts, label)``, with parts regenerated if enabled."""
        im, parts, label = super().get_example(i)
        if not self.uniform_parts:
            return im, parts, label

        generated = utils.uniform_parts(im, ratio=self.ratio)
        return im, generated, label
class RandomBlackOutMixin(BaseMixin):
    """Mixin that keeps only a random selection of parts flagged.

    Args:
        seed: seed for the private ``numpy.random.RandomState``.
        rnd_select: if true, ``get_example`` rewrites the last column of
            ``parts`` so that only randomly chosen rows are set to 1.
        n_parts: forwarded to ``utils.random_idxs`` as ``n_parts``.
        *args, **kwargs: forwarded to the next ``__init__`` in the MRO.
    """

    def __init__(self, seed=None, rnd_select=False, n_parts=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.rnd = np.random.RandomState(seed)
        self.rnd_select = rnd_select
        self.n_parts = n_parts

    def get_example(self, i):
        """Return ``(im, parts, label)``, with a random part selection if enabled.

        The last column of ``parts`` is zeroed and then set to 1 for the
        rows chosen by ``utils.random_idxs`` (presumably a visibility
        flag -- confirm against ``utils.visible_part_locs``).
        """
        im, parts, lab = super().get_example(i)
        if self.rnd_select:
            idxs, xy = utils.visible_part_locs(parts)
            rnd_idxs = utils.random_idxs(idxs, rnd=self.rnd, n_parts=self.n_parts)
            # Work on a copy: the array handed up by the parent may alias
            # data stored on the annotation object, and overwriting the
            # flag column in place would then change what later accesses
            # of the same example see as visible.
            # NOTE(review): assumes ``parts`` is array-like with ``.copy()``
            # (e.g. a numpy array) -- confirm against the annotation class.
            parts = parts.copy()
            parts[:, -1] = 0
            parts[rnd_idxs, -1] = 1
        return im, parts, lab
class Dataset(RandomBlackOutMixin, UniformPartMixin, BBCropMixin, ImageReadMixin):
    """Concrete dataset combining all mixins.

    ``get_example`` calls run down the MRO via ``super()``, so the steps
    are applied in reverse base-class order: image reading first, then
    bounding-box cropping, uniform parts, and finally random part
    selection.
    """
|