
minor fix for MPI support

Dimitri Korsch, 4 years ago
parent
commit
7fa5b00e7a
2 changed files with 13 additions and 5 deletions
  1. cvfinetune/finetuner/base.py (+7 -3)
  2. cvfinetune/finetuner/mpi.py (+6 -2)

cvfinetune/finetuner/base.py (+7 -3)

@@ -27,13 +27,17 @@ class DefaultFinetuner(mixins._ModelMixin, mixins._DatasetMixin, mixins._Trainer
 		self.init_updater()
 		self.init_evaluator()
 
+	def init_device(self):
+		self.device = cuda.get_device_from_id(self.device_id)
+		self.device.use()
+		return self.device
+
+
 	def gpu_config(self, opts):
 		if -1 in opts.gpu:
 			self.device_id = -1
 		else:
 			self.device_id = opts.gpu[0]
 
-		self.device = cuda.get_device_from_id(self.device_id)
-		self.device.use()
-		return self.device
+		return self.init_device()
 

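In base.py, the commit extracts the device activation out of gpu_config() into a new init_device() hook, so a subclass can choose its own device_id and still reuse the activation logic. A minimal sketch of the resulting flow, assuming Chainer's cuda backend and an opts.gpu list of GPU ids as in the original (the mixin base classes are omitted here):

    from chainer.backends import cuda

    class DefaultFinetuner:  # sketch only; the real class also inherits several mixins

        def init_device(self):
            # Resolve the numeric id into a Device object and activate it.
            # get_device_from_id() returns a DummyDevice for negative ids,
            # so the CPU case (-1) passes through .use() as a no-op.
            self.device = cuda.get_device_from_id(self.device_id)
            self.device.use()
            return self.device

        def gpu_config(self, opts):
            # Fall back to CPU if -1 appears anywhere in the GPU list,
            # otherwise take the first configured GPU id.
            self.device_id = -1 if -1 in opts.gpu else opts.gpu[0]
            return self.init_device()
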
cvfinetune/finetuner/mpi.py (+6 -2)

@@ -23,14 +23,18 @@ class MPIFinetuner(DefaultFinetuner):
 
 		if self.mpi:
 			if len(opts.gpu) > 1:
-				self.device = opts.gpu[self.comm.rank]
+				self.device_id = opts.gpu[self.comm.rank]
 			else:
-				self.device += self.comm.intra_rank
+				self.device_id += self.comm.intra_rank
+
 			ranks = f"{self.comm.rank}|{self.comm.intra_rank}|{self.comm.inter_rank}"
 			logging.info(f"Node with ranks {ranks} assigned to GPU #{self.device}")
 		else:
 			logging.warn("Using MPIFinetuner without setting a communicator!")
 
+		return self.init_device()
+
+
 	def scatter_datasets(self):
 		if self.mpi:
 			self.train_data = scatter(self.train_data, self.comm)
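
In mpi.py, the fix makes the MPI subclass update the integer device_id rather than the already-resolved device object: with several GPUs configured, each rank indexes opts.gpu by its global rank; with a single base id, the node-local intra_rank is added as an offset. The per-rank id only takes effect in the trailing init_device() call, which runs after the log line. A hedged usage sketch follows; the MPIFinetuner constructor signature and the call order are assumptions, and scatter() presumably wraps chainermn.scatter_dataset:

    # Launch under MPI, e.g.: mpiexec -n 4 python train.py --gpu 0
    from argparse import Namespace

    import chainermn
    from cvfinetune.finetuner.mpi import MPIFinetuner

    comm = chainermn.create_communicator("pure_nccl")
    opts = Namespace(gpu=[0])  # stand-in for the real CLI options

    # With opts.gpu == [0], each rank on a node gets GPU 0 + intra_rank,
    # i.e. GPUs 0..3 on a single 4-GPU node.
    tuner = MPIFinetuner(opts, comm=comm)  # assumed constructor signature
    tuner.gpu_config(opts)       # picks the per-rank device_id, then init_device()
    tuner.scatter_datasets()     # distributes the datasets across ranks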