from abc import ABC

from chainer_addons.dataset import AugmentationMixin
from chainer_addons.dataset import PreprocessMixin
from cvdatasets.dataset import AnnotationsReadMixin
from cvdatasets.dataset import CroppedPartMixin
from cvdatasets.dataset import IteratorMixin

from finetune.dataset import _base_mixin


class _parts_mixin(ABC):
    """Mixin that turns an image object into a list of part crops plus the full image.

    Relies on the next ``get_example`` in the MRO returning an object with
    ``visible_crops``, ``im_array`` and ``label`` attributes, and on
    ``self.label_shift`` being provided by another class in the MRO
    (presumably ``_base_mixin`` -- confirm).
    """

    def get_example(self, i):
        """Return ``(parts, label)`` for example ``i``.

        ``parts`` is a list holding every crop returned by
        ``visible_crops(None)``, followed by the full image array as the
        last element. ``label`` is the image object's label plus
        ``self.label_shift``.
        """
        im_obj = super().get_example(i)
        # NOTE(review): None is passed as the crop ratio argument; its exact
        # meaning is defined by cvdatasets -- confirm.
        crops = im_obj.visible_crops(None)
        # list() keeps the concatenation an append even if visible_crops
        # returns an array or tuple: with an array, "+ [...]" would broadcast
        # an element-wise addition instead of appending the global image.
        parts = list(crops) + [im_obj.im_array]
        return parts, im_obj.label + self.label_shift


class PartsDataset(
    _base_mixin,
    # augmentation and preprocessing
    AugmentationMixin,
    PreprocessMixin,
    _parts_mixin,
    # random uniform region selection
    CroppedPartMixin,
    # reads image
    AnnotationsReadMixin,
    IteratorMixin,
):
    """Dataset yielding part crops, optionally together with the global image.

    Args:
        *args: Forwarded unchanged to the next ``__init__`` in the MRO.
        no_glob: If True, ``get_example`` returns ``(X_parts, y)``;
            otherwise it returns ``(X_parts, X_glob, y)``.
        **kwargs: Forwarded unchanged to the next ``__init__`` in the MRO.
    """

    def __init__(self, *args, no_glob=False, **kwargs):
        # no_glob is keyword-only: with the old signature
        # ``(self, no_glob=False, *args, **kwargs)`` the first positional
        # argument intended for the base classes was captured as no_glob
        # and never forwarded to them.
        super().__init__(*args, **kwargs)
        self.no_glob = no_glob

    def get_example(self, i):
        """Split the parts list into part crops and the global image.

        The last element of ``X`` is taken as the global image, matching the
        order built by ``_parts_mixin.get_example``. This assumes the
        augmentation/preprocessing mixins earlier in the MRO preserve that
        ordering -- TODO confirm.
        """
        X, y = super().get_example(i)
        X_parts, X_glob = X[:-1], X[-1]

        if self.no_glob:
            return X_parts, y

        return X_parts, X_glob, y