瀏覽代碼

refactored chainer mixins

Dimitri Korsch 6 年之前
父節點
當前提交
d557e7c317

+ 1 - 0
nabirds/dataset/mixins/chainer_mixins/__init__.py

@@ -0,0 +1 @@
+from .iterator_mixin import IteratorMixin

+ 14 - 0
nabirds/dataset/mixins/chainer_mixins/base.py

@@ -0,0 +1,14 @@
+try:
+	import chainer
+except ImportError:
+	has_chainer = False
+else:
+	has_chainer = True
+
+from abc import ABC
+
+class BaseChainerMixin(ABC):
+
+	def chainer_check(self):
+		global has_chainer
+		assert has_chainer, "Please install chainer!"

+ 4 - 9
nabirds/dataset/mixins/chainer_mixins.py → nabirds/dataset/mixins/chainer_mixins/iterator_mixin.py

@@ -1,17 +1,12 @@
 import numpy as np
 import logging
 
-try:
-	import chainer
-	from chainer.iterators import SerialIterator, MultiprocessIterator
-except ImportError:
-	has_chainer = False
-else:
-	has_chainer = True
+from .base import BaseChainerMixin
 
-class IteratorMixin(object):
+class IteratorMixin(BaseChainerMixin):
 	def new_iterator(self, n_jobs, batch_size, repeat=True, shuffle=True, n_prefetch=2):
-		assert has_chainer, "Please install chainer!"
+		self.chainer_check()
+		from chainer.iterators import SerialIterator, MultiprocessIterator
 
 		if n_jobs > 0:
 			it = MultiprocessIterator(self,