dataset.py 1.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344
  1. from abc import ABC
  2. from chainer_addons.dataset import AugmentationMixin
  3. from chainer_addons.dataset import PreprocessMixin
  4. from cvdatasets.dataset import AnnotationsReadMixin
  5. from cvdatasets.dataset import CroppedPartMixin
  6. from cvdatasets.dataset import IteratorMixin
  7. from finetune.dataset import _base_mixin
  8. class _parts_mixin(ABC):
  9. def get_example(self, i):
  10. im_obj = super(_parts_mixin, self).get_example(i)
  11. crops = im_obj.visible_crops(None)
  12. parts = crops + [im_obj.im_array]
  13. return parts, im_obj.label + self.label_shift
  14. class PartsDataset(_base_mixin,
  15. # augmentation and preprocessing
  16. AugmentationMixin, PreprocessMixin,
  17. _parts_mixin,
  18. # random uniform region selection
  19. CroppedPartMixin,
  20. # reads image
  21. AnnotationsReadMixin,
  22. IteratorMixin):
  23. def __init__(self, no_glob=False, *args, **kwargs):
  24. super(PartsDataset, self).__init__(*args, **kwargs)
  25. # mask = self.labels < 10
  26. # self.uuids = self.uuids[mask]
  27. self.no_glob = no_glob
  28. def get_example(self, i):
  29. X, y = super(PartsDataset, self).get_example(i)
  30. X_parts, X_glob = X[:-1], X[-1]
  31. if self.no_glob:
  32. return X_parts, y
  33. else:
  34. return X_parts, X_glob, y