chainer_mixins.py 793 B

12345678910111213141516171819202122232425262728
  1. import numpy as np
  2. import logging
  3. try:
  4. import chainer
  5. from chainer.iterators import SerialIterator, MultiprocessIterator
  6. except ImportError:
  7. has_chainer = False
  8. else:
  9. has_chainer = True
  class IteratorMixin(object):
      """Mixin that lets a dataset object build a Chainer iterator over itself.

      ``new_iterator`` passes ``self`` directly to the Chainer iterator as the
      dataset and calls ``len(self)``, so the host class must at least define
      ``__len__``. It presumably also needs ``__getitem__`` for the iterators
      to index into it (TODO confirm against concrete subclasses).
      """

      def new_iterator(self, n_jobs, batch_size, repeat=True, shuffle=True, n_prefetch=2):
          """Create an iterator over ``self`` and count its batches per pass.

          Args:
              n_jobs: Number of worker processes. A value ``> 0`` selects
                  ``MultiprocessIterator`` with that many processes; otherwise
                  ``SerialIterator`` is used.
              batch_size: Number of examples per batch, forwarded to the iterator.
              repeat: Forwarded unchanged to the iterator.
              shuffle: Forwarded unchanged to the iterator.
              n_prefetch: Forwarded to ``MultiprocessIterator`` only; it is
                  ignored when the serial iterator is chosen.

          Returns:
              A tuple ``(iterator, n_batches)``, where ``n_batches`` is
              ``ceil(len(self) / iterator.batch_size)``. That is the number of
              batches in one pass over the data, counting a final partial batch.

          Raises:
              ImportError: If chainer could not be imported when this module
                  was loaded.
          """
          # Explicit raise rather than ``assert``: asserts are stripped under
          # ``python -O``, and the failure would then surface later as an
          # unhelpful NameError on the iterator class.
          if not has_chainer:
              raise ImportError("Please install chainer!")
          if n_jobs > 0:
              it = MultiprocessIterator(self,
                                        n_processes=n_jobs,
                                        n_prefetch=n_prefetch,
                                        batch_size=batch_size,
                                        repeat=repeat, shuffle=shuffle)
          else:
              it = SerialIterator(self,
                                  batch_size=batch_size,
                                  repeat=repeat, shuffle=shuffle)
          # Lazy %-style args produce the same message as the old str.format call.
          logging.info("Using %s with batch size %s", it.__class__.__name__, it.batch_size)
          # Integer ceiling division. It is exact for any length, and unlike
          # ``np.ceil(len(self) / batch_size)`` it is also correct on Python 2,
          # where ``/`` between ints floors before the ceil is applied.
          batch = it.batch_size
          n_batches = (len(self) + batch - 1) // batch
          return it, n_batches