
fixed MPIFinetuner

Dimitri Korsch · 3 years ago
commit c8f8e42105
1 changed file with 24 additions and 19 deletions

+24 -19  cvfinetune/finetuner/mpi.py

@@ -12,31 +12,32 @@ class MPIFinetuner(DefaultFinetuner):
 		self.comm = comm
 		super(MPIFinetuner, self).__init__(*args, **kwargs)
 
-	@property
-	def mpi(self):
-		return self.comm is not None
-
 	@property
 	def mpi_main_process(self):
 		return not (self.comm is not None and self.comm.rank != 0)
 
-
 	@property
 	def no_observe(self):
 		return self.no_sacred or not self.mpi_main_process
 
-	def gpu_config(self, devices):
-
-		if not self.mpi:
+	def check_mpi(self):
+		if self.comm is None:
 			msg = "Using MPIFinetuner without setting a communicator!"
 			warnings.warn(msg)
 			logging.warn(msg)
+			return False
+
+		return True
+
+	def gpu_config(self, devices):
+
+		if not self.check_mpi():
 			return super(MPIFinetuner, self).gpu_config(devices)
 
-		if len(devices) > 1:
-			self.device_id = devices[self.comm.rank]
+		if len(devices) == 1:
+			self.device_id = devices[0] + self.comm.intra_rank
 		else:
-			self.device_id += self.comm.intra_rank
+			self.device_id = devices[self.comm.rank]
 
 		device = self.init_device()
 		ranks = f"{self.comm.rank} | {self.comm.intra_rank} | {self.comm.inter_rank}"
@@ -45,7 +46,7 @@ class MPIFinetuner(DefaultFinetuner):
 
 
 	def scatter_datasets(self):
-		if self.mpi:
+		if self.check_mpi():
 			self.train_data = scatter(self.train_data, self.comm)
 			self.val_data = scatter(self.val_data, self.comm)
 		else:
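
For context, `scatter` in this hunk distributes the datasets from the root process to all workers. A minimal sketch of such a helper on top of ChainerMN's public API, assuming it wraps `chainermn.scatter_dataset` (the wrapper name comes from the diff; the `shuffle` choice is illustrative):

import chainermn

def scatter(dataset, comm):
    # Rank 0 holds the full dataset; every rank gets back an
    # equal-sized, disjoint shard. Non-root ranks may pass None.
    return chainermn.scatter_dataset(dataset, comm, shuffle=True)
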
@@ -61,18 +62,22 @@ class MPIFinetuner(DefaultFinetuner):
 
 		self.scatter_datasets()
 
-	def init_optimizer(self, opts):
-		super(MPIFinetuner, self).init_optimizer(opts)
+	def init_optimizer(self):
+		super(MPIFinetuner, self).init_optimizer()
+
+		if not self.check_mpi():
+			return
 
-		if self.mpi:
-			self.opt = chainermn.create_multi_node_optimizer(self.opt, self.comm)
+		self.opt = chainermn.create_multi_node_optimizer(self.opt, self.comm)
 
 	def init_evaluator(self):
 		super(MPIFinetuner, self).init_evaluator()
 
-		if self.mpi:
-			self.evaluator = chainermn.create_multi_node_evaluator(
-				self.evaluator, self.comm)
+		if not self.check_mpi():
+			return
+
+		self.evaluator = chainermn.create_multi_node_evaluator(
+			self.evaluator, self.comm)
 
 	def run(self, trainer_cls, opts, *args, **kwargs):
 		if not self.mpi_main_process:
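
The optimizer and evaluator hunks follow the standard ChainerMN pattern: wrap the single-node objects so gradients and evaluation results are aggregated over the communicator. A minimal usage sketch with placeholder names (`MyNet`, `val_iter`, and `device_id` are illustrative, not from the repository):

import chainer
import chainermn
from chainer.training import extensions

comm = chainermn.create_communicator("pure_nccl")  # NCCL-backed GPU communicator

model = chainer.links.Classifier(MyNet())  # MyNet: placeholder chainer.Chain

# Gradients are all-reduced across workers on every update step.
optimizer = chainermn.create_multi_node_optimizer(
    chainer.optimizers.MomentumSGD(lr=0.01), comm)
optimizer.setup(model)

# Each worker evaluates its own data shard; results are reduced over comm.
evaluator = chainermn.create_multi_node_evaluator(
    extensions.Evaluator(val_iter, model, device=device_id), comm)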