Browse Source

added IteratorMixin in order to create chainer Iterators

Dimitri Korsch 6 years ago
parent
commit
ce1c0d36c7
2 changed files with 30 additions and 1 deletions
  1. 2 1
      nabirds/dataset/__init__.py
  2. 28 0
      nabirds/dataset/mixins/chainer_mixins.py

+ 2 - 1
nabirds/dataset/__init__.py

@@ -1,9 +1,10 @@
 from .mixins.reading import AnnotationsReadMixin, ImageListReadingMixin
 from .mixins.parts import PartMixin, RevealedPartMixin, CroppedPartMixin
 from .mixins.features import PreExtractedFeaturesMixin
+from .mixins.chainer_mixins import IteratorMixin
 
 
-class ImageWrapperDataset(PartMixin, PreExtractedFeaturesMixin, AnnotationsReadMixin):
+class ImageWrapperDataset(PartMixin, PreExtractedFeaturesMixin, AnnotationsReadMixin, IteratorMixin):
 	pass
 
 class Dataset(ImageWrapperDataset):

+ 28 - 0
nabirds/dataset/mixins/chainer_mixins.py

@@ -0,0 +1,28 @@
+import numpy as np
+import logging
+
+try:
+	import chainer
+	from chainer.iterators import SerialIterator, MultiprocessIterator
+except ImportError:
+	has_chainer = False
+else:
+	has_chainer = True
+
+class IteratorMixin(object):
+	def new_iterator(self, n_jobs, batch_size, repeat=True, shuffle=True, n_prefetch=2):
+		assert has_chainer, "Please install chainer!"
+
+		if n_jobs > 0:
+			it = MultiprocessIterator(self,
+				n_processes=n_jobs,
+				n_prefetch=n_prefetch,
+				batch_size=batch_size,
+				repeat=repeat, shuffle=shuffle)
+		else:
+			it = SerialIterator(self,
+				batch_size=batch_size,
+				repeat=repeat, shuffle=shuffle)
+		logging.info("Using {it.__class__.__name__} with batch size {it.batch_size}".format(it=it))
+		n_batches = int(np.ceil(len(self) / it.batch_size))
+		return it, n_batches