dataset.py 1.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849
  1. import numpy as np
  2. import abc
  3. from chainer_addons.dataset import AugmentationMixin
  4. from chainer_addons.dataset import PreprocessMixin
  5. from cvdatasets.dataset import AnnotationsReadMixin
  6. from cvdatasets.dataset import RevealedPartMixin
  7. from cvdatasets.dataset import IteratorMixin
  8. class _pre_augmentation_mixin(abc.ABC):
  9. """ This mixin discards the parts from the ImageWrapper object
  10. and shifts the labels
  11. """
  12. label_shift = 1
  13. def get_example(self, i):
  14. im_obj = super(_pre_augmentation_mixin, self).get_example(i)
  15. im, parts, lab = im_obj.as_tuple()
  16. return im, lab + self.label_shift
  17. class _base_mixin(abc.ABC):
  18. """ This mixin converts images,that are in range
  19. [0..1] to the range [-1..1]
  20. """
  21. def get_example(self, i):
  22. im, lab = super(_base_mixin, self).get_example(i)
  23. if isinstance(im, list):
  24. im = np.array(im)
  25. if np.logical_and(0 <= im, im <= 1).all():
  26. im = im * 2 -1
  27. return im, lab
  28. class BaseDataset(_base_mixin,
  29. # augmentation and preprocessing
  30. AugmentationMixin, PreprocessMixin,
  31. _pre_augmentation_mixin,
  32. # random uniform region selection
  33. RevealedPartMixin,
  34. # reads image
  35. AnnotationsReadMixin,
  36. IteratorMixin):
  37. """Commonly used dataset constellation"""