瀏覽代碼

added constructor argument to use either threads or processes for multi-job loading

Dimitri Korsch 2 年之前
父節點
當前提交
45465ff6b7
共有 1 個文件被更改,包括 17 次插入11 次删除
  1. 17 11
      cvfinetune/finetuner/mixins/iterator.py

+ 17 - 11
cvfinetune/finetuner/mixins/iterator.py

@@ -12,18 +12,20 @@ class _IteratorMixin(BaseMixin):
                  *args,
                  *args,
                  batch_size: int = 32,
                  batch_size: int = 32,
                  n_jobs: int = 1,
                  n_jobs: int = 1,
+                 use_threads: bool = False,
                  **kwargs):
                  **kwargs):
-    	super().__init__(*args, **kwargs)
+        super().__init__(*args, **kwargs)
 
 
-    	self._batch_size = batch_size
-    	self._n_jobs = n_jobs
+        self._batch_size = batch_size
+        self._n_jobs = n_jobs
+        self._use_threads = use_threads
 
 
 
 
     def new_iterator(self, ds, **kwargs):
     def new_iterator(self, ds, **kwargs):
-    	if hasattr(ds, "new_iterator"):
-    		return ds.new_iterator(**kwargs)
-    	else:
-    		return new_iterator(ds, **kwargs)
+        if hasattr(ds, "new_iterator"):
+            return ds.new_iterator(**kwargs)
+        else:
+            return new_iterator(ds, **kwargs)
 
 
     def init_iterators(self):
     def init_iterators(self):
         """Creates training and validation iterators from training and validation datasets"""
         """Creates training and validation iterators from training and validation datasets"""
@@ -31,11 +33,15 @@ class _IteratorMixin(BaseMixin):
         self._check_attr("val_data")
         self._check_attr("val_data")
         self._check_attr("train_data")
         self._check_attr("train_data")
 
 
-        kwargs = dict(n_jobs=self._n_jobs, batch_size=self._batch_size)
+        kwargs = dict(
+            n_jobs=self._n_jobs,
+            batch_size=self._batch_size,
+            use_threads=self._use_threads,
+        )
 
 
         self.train_iter, _ = self.new_iterator(self.train_data,
         self.train_iter, _ = self.new_iterator(self.train_data,
-        	                                   **kwargs)
+                                               **kwargs)
 
 
         self.val_iter, _ = self.new_iterator(self.val_data,
         self.val_iter, _ = self.new_iterator(self.val_data,
-        	                                 repeat=False, shuffle=False,
-        	                                 **kwargs)
+                                             repeat=False, shuffle=False,
+                                             **kwargs)