Browse source

refactored base and MPI finetuner a bit

Dimitri Korsch 3 years ago
parent
commit
0d825959cb
2 files changed with 17 additions and 14 deletions
  1. cvfinetune/finetuner/base.py (+3 -3)
  2. cvfinetune/finetuner/mpi.py (+14 -11)

cvfinetune/finetuner/base.py (+3 -3)

@@ -1,6 +1,5 @@
 import chainer
-
-from chainer.backends import cuda
+import logging
 
 from cvfinetune.finetuner import mixins
 
@@ -28,7 +27,7 @@ class DefaultFinetuner(mixins._ModelMixin, mixins._DatasetMixin, mixins._Trainer
 		self.init_evaluator()
 
 	def init_device(self):
-		self.device = cuda.get_device_from_id(self.device_id)
+		self.device = chainer.get_device(self.device_id)
 		self.device.use()
 		return self.device
 
@@ -39,5 +38,6 @@ class DefaultFinetuner(mixins._ModelMixin, mixins._DatasetMixin, mixins._Trainer
 		else:
 			self.device_id = opts.gpu[0]
 
+		logging.info(f"Using device {self.device_id}")
 		return self.init_device()
 
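The switch from cuda.get_device_from_id to chainer.get_device makes init_device backend-agnostic: chainer.get_device accepts plain integer ids (a non-negative id selects a CUDA GPU, -1 selects the CPU/NumPy backend) as well as string specifiers. A minimal sketch of the call as used in init_device above, with illustrative ids:

	import chainer

	# A non-negative integer id selects a CUDA GPU, as in init_device().
	gpu = chainer.get_device(0)   # same device as the spec "@cupy:0"
	gpu.use()                     # make this GPU current, as init_device() does

	# -1 selects the CPU (NumPy) backend, so the same code path runs without a GPU.
	cpu = chainer.get_device(-1)  # same device as the spec "@numpy"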

cvfinetune/finetuner/mpi.py (+14 -11)

@@ -1,5 +1,7 @@
 import chainermn
 import logging
+import warnings
+
 from chainermn import scatter_dataset as scatter
 
 from cvfinetune.finetuner.base import DefaultFinetuner
@@ -19,20 +21,21 @@ class MPIFinetuner(DefaultFinetuner):
 		return not (self.comm is not None and self.comm.rank != 0)
 
 	def gpu_config(self, opts):
-		device = super(MPIFinetuner, self).gpu_config(opts)
 
-		if self.mpi:
-			if len(opts.gpu) > 1:
-				self.device_id = opts.gpu[self.comm.rank]
-			else:
-				self.device_id += self.comm.intra_rank
-
-			device = self.init_device()
-			ranks = f"{self.comm.rank}|{self.comm.intra_rank}|{self.comm.inter_rank}"
-			logging.info(f"Node with ranks {ranks} assigned to {device}")
+		if not self.mpi:
+			msg = "Using MPIFinetuner without setting a communicator!"
+			warnings.warn(msg)
+			logging.warning(msg)
+			return super(MPIFinetuner, self).gpu_config(opts)
+
+		if len(opts.gpu) > 1:
+			self.device_id = opts.gpu[self.comm.rank]
 		else:
-			logging.warn("Using MPIFinetuner without setting a communicator!")
+			self.device_id += self.comm.intra_rank
 
+		device = self.init_device()
+		ranks = f"{self.comm.rank} | {self.comm.intra_rank} | {self.comm.inter_rank}"
+		logging.info(f"Node with ranks {ranks} assigned to {device}")
 		return device
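The rewrite turns the MPI check into an early-exit guard: without a communicator the class now both warns and falls back to the single-process behaviour of DefaultFinetuner. With a communicator, each process picks its GPU either from an explicit per-rank list or by offsetting a single base id with its intra-node rank. A standalone sketch of that assignment rule (the communicator name and gpu list are illustrative):

	import chainer
	import chainermn

	comm = chainermn.create_communicator("hierarchical")

	gpus = [0]  # e.g. parsed from a --gpu option; one entry means "offset by intra-node rank"
	if len(gpus) > 1:
		# explicit mapping: one GPU id per global rank
		device_id = gpus[comm.rank]
	else:
		# single base id: processes on the same node get consecutive GPUs
		device_id = gpus[0] + comm.intra_rank

	device = chainer.get_device(device_id)
	device.use()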