dataset.py

import logging

import numpy as np


def new_iterator(data, n_jobs, batch_size, repeat=True, shuffle=True, n_prefetch=2):
    # Imported lazily so chainer is only required when an iterator is built.
    from chainer.iterators import SerialIterator, MultiprocessIterator

    if n_jobs > 0:
        # Load and prefetch batches in background worker processes.
        it = MultiprocessIterator(data,
                                  n_processes=n_jobs,
                                  n_prefetch=n_prefetch,
                                  batch_size=batch_size,
                                  repeat=repeat, shuffle=shuffle)
    else:
        # Fall back to single-process iteration.
        it = SerialIterator(data,
                            batch_size=batch_size,
                            repeat=repeat, shuffle=shuffle)

    logging.info("Using {it.__class__.__name__} with batch size {it.batch_size}".format(it=it))

    # Number of batches per epoch; the last batch may be smaller.
    n_batches = int(np.ceil(len(data) / it.batch_size))
    return it, n_batches
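

# Minimal usage sketch, not part of the original file: it assumes any sized,
# indexable dataset; the names "train_data" and "train_it" are illustrative.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    train_data = list(range(100))
    # n_jobs=0 selects SerialIterator; n_jobs > 0 would use MultiprocessIterator.
    train_it, n_batches = new_iterator(train_data, n_jobs=0, batch_size=16)
    print("batches per epoch:", n_batches)  # ceil(100 / 16) == 7
    first_batch = next(train_it)
    print("first batch size:", len(first_batch))  # 16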