@@ -0,0 +1 @@
+from .iterator_mixin import IteratorMixin
@@ -0,0 +1,14 @@
+try:
+ import chainer
+except ImportError:
+ has_chainer = False
+else:
+ has_chainer = True
+
+from abc import ABC
+class BaseChainerMixin(ABC):
+ def chainer_check(self):
+ global has_chainer
+ assert has_chainer, "Please install chainer!"
@@ -1,17 +1,12 @@
import numpy as np
import logging
-try:
- import chainer
- from chainer.iterators import SerialIterator, MultiprocessIterator
-except ImportError:
- has_chainer = False
-else:
- has_chainer = True
+from .base import BaseChainerMixin
-class IteratorMixin(object):
+class IteratorMixin(BaseChainerMixin):
def new_iterator(self, n_jobs, batch_size, repeat=True, shuffle=True, n_prefetch=2):
- assert has_chainer, "Please install chainer!"
+ self.chainer_check()
+ from chainer.iterators import SerialIterator, MultiprocessIterator
if n_jobs > 0:
it = MultiprocessIterator(self,