iterator_mixin.py 726 B

1234567891011121314151617181920212223
  1. import numpy as np
  2. import logging
  3. from .base import BaseChainerMixin
  4. class IteratorMixin(BaseChainerMixin):
  5. def new_iterator(self, n_jobs, batch_size, repeat=True, shuffle=True, n_prefetch=2):
  6. self.chainer_check()
  7. from chainer.iterators import SerialIterator, MultiprocessIterator
  8. if n_jobs > 0:
  9. it = MultiprocessIterator(self,
  10. n_processes=n_jobs,
  11. n_prefetch=n_prefetch,
  12. batch_size=batch_size,
  13. repeat=repeat, shuffle=shuffle)
  14. else:
  15. it = SerialIterator(self,
  16. batch_size=batch_size,
  17. repeat=repeat, shuffle=shuffle)
  18. logging.info("Using {it.__class__.__name__} with batch size {it.batch_size}".format(it=it))
  19. n_batches = int(np.ceil(len(self) / it.batch_size))
  20. return it, n_batches