|
@@ -12,18 +12,20 @@ class _IteratorMixin(BaseMixin):
|
|
*args,
|
|
*args,
|
|
batch_size: int = 32,
|
|
batch_size: int = 32,
|
|
n_jobs: int = 1,
|
|
n_jobs: int = 1,
|
|
|
|
+ use_threads: bool = False,
|
|
**kwargs):
|
|
**kwargs):
|
|
- super().__init__(*args, **kwargs)
|
|
|
|
|
|
+ super().__init__(*args, **kwargs)
|
|
|
|
|
|
- self._batch_size = batch_size
|
|
|
|
- self._n_jobs = n_jobs
|
|
|
|
|
|
+ self._batch_size = batch_size
|
|
|
|
+ self._n_jobs = n_jobs
|
|
|
|
+ self._use_threads = use_threads
|
|
|
|
|
|
|
|
|
|
def new_iterator(self, ds, **kwargs):
|
|
def new_iterator(self, ds, **kwargs):
|
|
- if hasattr(ds, "new_iterator"):
|
|
|
|
- return ds.new_iterator(**kwargs)
|
|
|
|
- else:
|
|
|
|
- return new_iterator(ds, **kwargs)
|
|
|
|
|
|
+ if hasattr(ds, "new_iterator"):
|
|
|
|
+ return ds.new_iterator(**kwargs)
|
|
|
|
+ else:
|
|
|
|
+ return new_iterator(ds, **kwargs)
|
|
|
|
|
|
def init_iterators(self):
|
|
def init_iterators(self):
|
|
"""Creates training and validation iterators from training and validation datasets"""
|
|
"""Creates training and validation iterators from training and validation datasets"""
|
|
@@ -31,11 +33,15 @@ class _IteratorMixin(BaseMixin):
|
|
self._check_attr("val_data")
|
|
self._check_attr("val_data")
|
|
self._check_attr("train_data")
|
|
self._check_attr("train_data")
|
|
|
|
|
|
- kwargs = dict(n_jobs=self._n_jobs, batch_size=self._batch_size)
|
|
|
|
|
|
+ kwargs = dict(
|
|
|
|
+ n_jobs=self._n_jobs,
|
|
|
|
+ batch_size=self._batch_size,
|
|
|
|
+ use_threads=self._use_threads,
|
|
|
|
+ )
|
|
|
|
|
|
self.train_iter, _ = self.new_iterator(self.train_data,
|
|
self.train_iter, _ = self.new_iterator(self.train_data,
|
|
- **kwargs)
|
|
|
|
|
|
+ **kwargs)
|
|
|
|
|
|
self.val_iter, _ = self.new_iterator(self.val_data,
|
|
self.val_iter, _ = self.new_iterator(self.val_data,
|
|
- repeat=False, shuffle=False,
|
|
|
|
- **kwargs)
|
|
|
|
|
|
+ repeat=False, shuffle=False,
|
|
|
|
+ **kwargs)
|