|
@@ -1,17 +1,12 @@
|
|
|
import numpy as np
|
|
|
import logging
|
|
|
|
|
|
-try:
|
|
|
- import chainer
|
|
|
- from chainer.iterators import SerialIterator, MultiprocessIterator
|
|
|
-except ImportError:
|
|
|
- has_chainer = False
|
|
|
-else:
|
|
|
- has_chainer = True
|
|
|
+from .base import BaseChainerMixin
|
|
|
|
|
|
-class IteratorMixin(object):
|
|
|
+class IteratorMixin(BaseChainerMixin):
|
|
|
def new_iterator(self, n_jobs, batch_size, repeat=True, shuffle=True, n_prefetch=2):
|
|
|
- assert has_chainer, "Please install chainer!"
|
|
|
+ self.chainer_check()
|
|
|
+ from chainer.iterators import SerialIterator, MultiprocessIterator
|
|
|
|
|
|
if n_jobs > 0:
|
|
|
it = MultiprocessIterator(self,
|