瀏覽代碼

minor changes in device initialization

Dimitri Korsch 5 年之前
父節點
當前提交
2b1f830114
共有 1 個文件被更改,包括 6 次插入4 次删除
  1. 6 4
      cvfinetune/finetuner/base.py

+ 6 - 4
cvfinetune/finetuner/base.py

@@ -26,8 +26,6 @@ class DefaultFinetuner(mixins._ModelMixin, mixins._DatasetMixin, mixins._Trainer
 		super(DefaultFinetuner, self).__init__(opts=opts, *args, **kwargs)
 
 		self.gpu_config(opts)
-		cuda.get_device_from_id(self.device).use()
-
 		self.read_annotations(opts)
 
 		self.init_model(opts)
@@ -43,7 +41,11 @@ class DefaultFinetuner(mixins._ModelMixin, mixins._DatasetMixin, mixins._Trainer
 
 	def gpu_config(self, opts):
 		if -1 in opts.gpu:
-			self.device = -1
+			self.device_id = -1
 		else:
-			self.device = opts.gpu[0]
+			self.device_id = opts.gpu[0]
+
+		self.device = cuda.get_device_from_id(self.device_id)
+		self.device.use()
+		return self.device