|
@@ -0,0 +1,28 @@
|
|
|
+import numpy as np
|
|
|
+import logging
|
|
|
+
|
|
|
+try:
|
|
|
+ import chainer
|
|
|
+ from chainer.iterators import SerialIterator, MultiprocessIterator
|
|
|
+except ImportError:
|
|
|
+ has_chainer = False
|
|
|
+else:
|
|
|
+ has_chainer = True
|
|
|
+
|
|
|
+class IteratorMixin(object):
|
|
|
+ def new_iterator(self, n_jobs, batch_size, repeat=True, shuffle=True, n_prefetch=2):
|
|
|
+ assert has_chainer, "Please install chainer!"
|
|
|
+
|
|
|
+ if n_jobs > 0:
|
|
|
+ it = MultiprocessIterator(self,
|
|
|
+ n_processes=n_jobs,
|
|
|
+ n_prefetch=n_prefetch,
|
|
|
+ batch_size=batch_size,
|
|
|
+ repeat=repeat, shuffle=shuffle)
|
|
|
+ else:
|
|
|
+ it = SerialIterator(self,
|
|
|
+ batch_size=batch_size,
|
|
|
+ repeat=repeat, shuffle=shuffle)
|
|
|
+ logging.info("Using {it.__class__.__name__} with batch size {it.batch_size}".format(it=it))
|
|
|
+ n_batches = int(np.ceil(len(self) / it.batch_size))
|
|
|
+ return it, n_batches
|