12345678910111213141516171819202122232425262728 |
# Module dependencies. chainer is optional: whether it could be imported is
# recorded in the module-level flag ``has_chainer``.
import logging

import numpy as np

try:
    # ``chainer`` itself is not referenced in this chunk; the import is kept
    # because other parts of the module may rely on it.
    import chainer
    from chainer.iterators import MultiprocessIterator, SerialIterator
    has_chainer = True
except ImportError:
    # chainer unavailable: the iterator names above stay undefined, so any
    # code using them must check ``has_chainer`` first.
    has_chainer = False
class IteratorMixin(object):
    """Mixin that builds a chainer iterator over the instance itself.

    The host class is passed directly to chainer's iterators as the dataset,
    and ``len(self)`` is used to count batches, so it must support ``len()``
    (and presumably indexing, as chainer iterators require — TODO confirm
    against host classes).
    """

    def new_iterator(self, n_jobs, batch_size, repeat=True, shuffle=True, n_prefetch=2):
        """Create a chainer iterator over this dataset.

        Args:
            n_jobs: Number of worker processes. A value > 0 selects
                ``MultiprocessIterator`` with ``n_processes=n_jobs``; any other
                value selects a single-process ``SerialIterator``.
            batch_size: Number of examples per batch.
            repeat: Passed through unchanged to the iterator.
            shuffle: Passed through unchanged to the iterator.
            n_prefetch: Passed to ``MultiprocessIterator`` only; ignored when
                ``SerialIterator`` is used.

        Returns:
            A tuple ``(iterator, n_batches)``. ``n_batches`` is
            ``ceil(len(self) / iterator.batch_size)`` as an ``int``, i.e. the
            number of batches needed to cover every example once, counting a
            trailing partial batch.

        Raises:
            ImportError: If chainer could not be imported.
        """
        # Explicit raise rather than ``assert``: asserts are removed under
        # ``python -O``, which would surface as a confusing NameError below.
        if not has_chainer:
            raise ImportError("Please install chainer!")

        if n_jobs > 0:
            it = MultiprocessIterator(self,
                                      n_processes=n_jobs,
                                      n_prefetch=n_prefetch,
                                      batch_size=batch_size,
                                      repeat=repeat, shuffle=shuffle)
        else:
            it = SerialIterator(self,
                                batch_size=batch_size,
                                repeat=repeat, shuffle=shuffle)

        # Lazy %-style args: formatting only happens if INFO is enabled.
        logging.info("Using %s with batch size %s",
                     type(it).__name__, it.batch_size)

        # Exact integer ceiling division. The former
        # ``np.ceil(len(self) / it.batch_size)`` floors first under Python 2
        # integer division (dropping a trailing partial batch) and otherwise
        # round-trips an integer count through float.
        n_batches = int(-(-len(self) // it.batch_size))
        return it, n_batches
|