Browse Source

added kwargs factory for dataset class

Dimitri Korsch 6 years ago
parent
commit
c7ae319bb2
1 changed files with 24 additions and 10 deletions
  1. 24 10
      cvfinetune/finetuner/base.py

+ 24 - 10
cvfinetune/finetuner/base.py

@@ -162,9 +162,10 @@ class _DatasetMixin(abc.ABC):
 		dataset and iterator creation.
 	"""
 
-	def __init__(self, dataset_cls, *args, **kwargs):
+	def __init__(self, dataset_cls, dataset_kwargs_factory, *args, **kwargs):
 		super(_DatasetMixin, self).__init__(*args, **kwargs)
 		self.dataset_cls = dataset_cls
+		self.dataset_kwargs_factory = dataset_kwargs_factory
 
 	@property
 	def n_classes(self):
@@ -172,10 +173,15 @@ class _DatasetMixin(abc.ABC):
 
 	def new_dataset(self, opts, size, subset, augment):
 		"""Creates a dataset for a specific subset and certain options"""
-		kwargs = dict(
+		if self.dataset_kwargs_factory is not None and callable(self.dataset_kwargs_factory):
+			kwargs = self.dataset_kwargs_factory(opts, subset, augment)
+		else:
+			kwargs = dict()
+
+		kwargs.update(dict(
 			subset=subset,
 			dataset_cls=self.dataset_cls,
-		)
+		))
 
 		# if opts.use_parts:
 		# 	kwargs.update(dict(
@@ -234,14 +240,22 @@ class _DatasetMixin(abc.ABC):
 	def init_iterators(self, opts):
 		"""Creates training and validation iterators from training and validation datasets"""
 
-		self.train_iter, _ = new_iterator(self.train_data,
-			opts.n_jobs, opts.batch_size
-		)
+		kwargs = dict(n_jobs=opts.n_jobs, batch_size=opts.batch_size)
+
+		if hasattr(self.train_data, "new_iterator"):
+			self.train_iter, _ = self.train_data.new_iterator(**kwargs)
+		else:
+			self.train_iter, _ = new_iterator(self.train_data, **kwargs)
+
+		if hasattr(self.val_data, "new_iterator"):
+			self.val_iter, _ = self.val_data.new_iterator(**kwargs,
+				repeat=False, shuffle=False
+			)
+		else:
+			self.val_iter, _ = new_iterator(self.val_data,
+				**kwargs, repeat=False, shuffle=False
+			)
 
-		self.val_iter, _ = new_iterator(self.val_data,
-			opts.n_jobs, opts.batch_size,
-			repeat=False, shuffle=False
-		)
 
 class _TrainerMixin(abc.ABC):
 	"""This mixin is responsible for updater, evaluator and trainer creation.