"""Mixin that builds chainer batch iterators over a dataset-like object."""
import logging

import numpy as np

# chainer is an optional dependency: remember whether it could be imported so
# that the failure surfaces when an iterator is requested, not at import time.
try:
    import chainer
    from chainer.iterators import SerialIterator, MultiprocessIterator
except ImportError:
    has_chainer = False
else:
    has_chainer = True


class IteratorMixin(object):
    """Add a ``new_iterator`` factory to a dataset-like class.

    The host object is handed to the chainer iterator as its dataset and
    ``len(self)`` is used to compute the number of batches, so the host class
    must at least implement ``__len__``.
    """

    def new_iterator(self, n_jobs, batch_size, repeat=True, shuffle=True,
                     n_prefetch=2):
        """Create a chainer iterator over ``self``.

        Args:
            n_jobs: number of worker processes. A value greater than zero
                selects ``MultiprocessIterator`` (passed as ``n_processes``);
                anything else selects ``SerialIterator``.
            batch_size: batch size forwarded to the iterator.
            repeat: forwarded to the iterator.
            shuffle: forwarded to the iterator.
            n_prefetch: forwarded to ``MultiprocessIterator`` only; ignored
                for the serial iterator.

        Returns:
            A tuple ``(iterator, n_batches)`` where ``n_batches`` is
            ``ceil(len(self) / iterator.batch_size)`` as an ``int``.

        Raises:
            AssertionError: if chainer could not be imported.
        """
        # Raised explicitly instead of using an ``assert`` statement so the
        # check is not stripped under ``python -O``; the exception type and
        # message stay the same as before for existing callers.
        if not has_chainer:
            raise AssertionError("Please install chainer!")

        if n_jobs > 0:
            it = MultiprocessIterator(
                self,
                n_processes=n_jobs,
                n_prefetch=n_prefetch,
                batch_size=batch_size,
                repeat=repeat,
                shuffle=shuffle,
            )
        else:
            it = SerialIterator(
                self,
                batch_size=batch_size,
                repeat=repeat,
                shuffle=shuffle,
            )

        logging.info(
            "Using %s with batch size %s",
            it.__class__.__name__,
            it.batch_size,
        )

        # Force true division: on Python 2 ``int / int`` floors, which would
        # silently drop the final partial batch from the count.
        n_batches = int(np.ceil(len(self) / float(it.batch_size)))
        return it, n_batches