Browse Source

moved function for new iterator creation

Dimitri Korsch 6 years ago
parent
commit
e6b5671c36

+ 3 - 16
nabirds/dataset/mixins/chainer_mixins/iterator_mixin.py

@@ -2,22 +2,9 @@ import numpy as np
 import logging
 
 from .base import BaseChainerMixin
+from nabirds.utils import new_iterator
 
 class IteratorMixin(BaseChainerMixin):
-	def new_iterator(self, n_jobs, batch_size, repeat=True, shuffle=True, n_prefetch=2):
+	def new_iterator(self, **kwargs):
 		self.chainer_check()
-		from chainer.iterators import SerialIterator, MultiprocessIterator
-
-		if n_jobs > 0:
-			it = MultiprocessIterator(self,
-				n_processes=n_jobs,
-				n_prefetch=n_prefetch,
-				batch_size=batch_size,
-				repeat=repeat, shuffle=shuffle)
-		else:
-			it = SerialIterator(self,
-				batch_size=batch_size,
-				repeat=repeat, shuffle=shuffle)
-		logging.info("Using {it.__class__.__name__} with batch size {it.batch_size}".format(it=it))
-		n_batches = int(np.ceil(len(self) / it.batch_size))
-		return it, n_batches
+		return new_iterator(data=self, **kwargs)

+ 1 - 0
nabirds/utils/__init__.py

@@ -60,3 +60,4 @@ class _MetaInfo(object):
 
 
 from .image import asarray, dimensions
+from .dataset import new_iterator

+ 19 - 0
nabirds/utils/dataset.py

@@ -0,0 +1,19 @@
+import logging
+import numpy as np
+
+def new_iterator(data, n_jobs, batch_size, repeat=True, shuffle=True, n_prefetch=2):
+	from chainer.iterators import SerialIterator, MultiprocessIterator
+
+	if n_jobs > 0:
+		it = MultiprocessIterator(data,
+			n_processes=n_jobs,
+			n_prefetch=n_prefetch,
+			batch_size=batch_size,
+			repeat=repeat, shuffle=shuffle)
+	else:
+		it = SerialIterator(data,
+			batch_size=batch_size,
+			repeat=repeat, shuffle=shuffle)
+	logging.info("Using {it.__class__.__name__} with batch size {it.batch_size}".format(it=it))
+	n_batches = int(np.ceil(len(data) / it.batch_size))
+	return it, n_batches