|
|
@@ -162,9 +162,10 @@ class _DatasetMixin(abc.ABC):
|
|
|
dataset and iterator creation.
|
|
|
"""
|
|
|
|
|
|
- def __init__(self, dataset_cls, *args, **kwargs):
|
|
|
+ def __init__(self, dataset_cls, dataset_kwargs_factory, *args, **kwargs):
|
|
|
super(_DatasetMixin, self).__init__(*args, **kwargs)
|
|
|
self.dataset_cls = dataset_cls
|
|
|
+ self.dataset_kwargs_factory = dataset_kwargs_factory
|
|
|
|
|
|
@property
|
|
|
def n_classes(self):
|
|
|
@@ -172,10 +173,15 @@ class _DatasetMixin(abc.ABC):
|
|
|
|
|
|
def new_dataset(self, opts, size, subset, augment):
|
|
|
"""Creates a dataset for a specific subset and certain options"""
|
|
|
- kwargs = dict(
|
|
|
+ if self.dataset_kwargs_factory is not None and callable(self.dataset_kwargs_factory):
|
|
|
+ kwargs = self.dataset_kwargs_factory(opts, subset, augment)
|
|
|
+ else:
|
|
|
+ kwargs = dict()
|
|
|
+
|
|
|
+ kwargs.update(dict(
|
|
|
subset=subset,
|
|
|
dataset_cls=self.dataset_cls,
|
|
|
- )
|
|
|
+ ))
|
|
|
|
|
|
# if opts.use_parts:
|
|
|
# kwargs.update(dict(
|
|
|
@@ -234,14 +240,22 @@ class _DatasetMixin(abc.ABC):
|
|
|
def init_iterators(self, opts):
|
|
|
"""Creates training and validation iterators from training and validation datasets"""
|
|
|
|
|
|
- self.train_iter, _ = new_iterator(self.train_data,
|
|
|
- opts.n_jobs, opts.batch_size
|
|
|
- )
|
|
|
+ kwargs = dict(n_jobs=opts.n_jobs, batch_size=opts.batch_size)
|
|
|
+
|
|
|
+ if hasattr(self.train_data, "new_iterator"):
|
|
|
+ self.train_iter, _ = self.train_data.new_iterator(**kwargs)
|
|
|
+ else:
|
|
|
+ self.train_iter, _ = new_iterator(self.train_data, **kwargs)
|
|
|
+
|
|
|
+ if hasattr(self.val_data, "new_iterator"):
|
|
|
+ self.val_iter, _ = self.val_data.new_iterator(**kwargs,
|
|
|
+ repeat=False, shuffle=False
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ self.val_iter, _ = new_iterator(self.val_data,
|
|
|
+ **kwargs, repeat=False, shuffle=False
|
|
|
+ )
|
|
|
|
|
|
- self.val_iter, _ = new_iterator(self.val_data,
|
|
|
- opts.n_jobs, opts.batch_size,
|
|
|
- repeat=False, shuffle=False
|
|
|
- )
|
|
|
|
|
|
class _TrainerMixin(abc.ABC):
|
|
|
"""This mixin is responsible for updater, evaluator and trainer creation.
|