Kaynağa Gözat

refactored base and MPI finetuner a bit

Dimitri Korsch 3 yıl önce
ebeveyn
işleme
0d825959cb
2 değiştirilmiş dosya ile 17 ekleme ve 14 silme
  1. 3 3
      cvfinetune/finetuner/base.py
  2. 14 11
      cvfinetune/finetuner/mpi.py

+ 3 - 3
cvfinetune/finetuner/base.py

@@ -1,6 +1,5 @@
 import chainer
-
-from chainer.backends import cuda
+import logging
 
 from cvfinetune.finetuner import mixins
 
@@ -28,7 +27,7 @@ class DefaultFinetuner(mixins._ModelMixin, mixins._DatasetMixin, mixins._Trainer
 		self.init_evaluator()
 
 	def init_device(self):
-		self.device = cuda.get_device_from_id(self.device_id)
+		self.device = chainer.get_device(self.device_id)
 		self.device.use()
 		return self.device
 
@@ -39,5 +38,6 @@ class DefaultFinetuner(mixins._ModelMixin, mixins._DatasetMixin, mixins._Trainer
 		else:
 			self.device_id = opts.gpu[0]
 
+		logging.info(f"Using device {device}")
 		return self.init_device()
 

+ 14 - 11
cvfinetune/finetuner/mpi.py

@@ -1,5 +1,7 @@
 import chainermn
 import logging
+import warnings
+
 from chainermn import scatter_dataset as scatter
 
 from cvfinetune.finetuner.base import DefaultFinetuner
@@ -19,20 +21,21 @@ class MPIFinetuner(DefaultFinetuner):
 		return not (self.comm is not None and self.comm.rank != 0)
 
 	def gpu_config(self, opts):
-		device = super(MPIFinetuner, self).gpu_config(opts)
 
-		if self.mpi:
-			if len(opts.gpu) > 1:
-				self.device_id = opts.gpu[self.comm.rank]
-			else:
-				self.device_id += self.comm.intra_rank
-
-			device = self.init_device()
-			ranks = f"{self.comm.rank}|{self.comm.intra_rank}|{self.comm.inter_rank}"
-			logging.info(f"Node with ranks {ranks} assigned to {device}")
+		if not self.mpi:
+			msg = "Using MPIFinetuner without setting a communicator!"
+			warnings.warn(msg)
+			logging.warn(msg)
+			return super(MPIFinetuner, self).gpu_config(opts)
+
+		if len(opts.gpu) > 1:
+			self.device_id = opts.gpu[self.comm.rank]
 		else:
-			logging.warn("Using MPIFinetuner without setting a communicator!")
+			self.device_id += self.comm.intra_rank
 
+		device = self.init_device()
+		ranks = f"{self.comm.rank} | {self.comm.intra_rank} | {self.comm.inter_rank}"
+		logging.info(f"Node with ranks {ranks} assigned to {device}")
 		return device