dataset.py 1.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546
  1. from abc import ABC
  2. from chainer_addons.dataset import AugmentationMixin
  3. from chainer_addons.dataset import PreprocessMixin
  4. from cvdatasets.dataset import AnnotationsReadMixin
  5. from cvdatasets.dataset import CroppedPartMixin
  6. from cvdatasets.dataset import IteratorMixin
  7. from finetune.dataset import _base_mixin
  8. class _parts_mixin(ABC):
  9. def get_example(self, i):
  10. im_obj = super(_parts_mixin, self).get_example(i)
  11. crops = im_obj.visible_crops(None)
  12. parts = crops + [im_obj.im_array]
  13. return parts, im_obj.label + self.label_shift
  14. class PartsDataset(_base_mixin,
  15. # augmentation and preprocessing
  16. AugmentationMixin, PreprocessMixin,
  17. _parts_mixin,
  18. # random uniform region selection
  19. CroppedPartMixin,
  20. # reads image
  21. AnnotationsReadMixin,
  22. IteratorMixin):
  23. def __init__(self, no_glob=False, *args, **kwargs):
  24. super(PartsDataset, self).__init__(*args, **kwargs)
  25. # mask = self.labels < 10
  26. # self.uuids = self.uuids[mask]
  27. self.no_glob = no_glob
  28. def get_example(self, i):
  29. X, y = super(PartsDataset, self).get_example(i)
  30. X_parts, X_glob = X[:-1], X[-1]
  31. if self.no_glob:
  32. return X_parts, y
  33. else:
  34. return X_parts, X_glob, y