@@ -12,31 +12,32 @@ class MPIFinetuner(DefaultFinetuner):
         self.comm = comm
         super(MPIFinetuner, self).__init__(*args, **kwargs)
 
-    @property
-    def mpi(self):
-        return self.comm is not None
-
     @property
     def mpi_main_process(self):
         return not (self.comm is not None and self.comm.rank != 0)
 
-
     @property
     def no_observe(self):
         return self.no_sacred or not self.mpi_main_process
 
-    def gpu_config(self, devices):
-
-        if not self.mpi:
+    def check_mpi(self):
+        if self.comm is None:
             msg = "Using MPIFinetuner without setting a communicator!"
             warnings.warn(msg)
             logging.warn(msg)
+            return False
+
+        return True
+
+    def gpu_config(self, devices):
+
+        if not self.check_mpi():
             return super(MPIFinetuner, self).gpu_config(devices)
 
-        if len(devices) > 1:
-            self.device_id = devices[self.comm.rank]
+        if len(devices) == 1:
+            self.device_id = devices[0] + self.comm.intra_rank
         else:
-            self.device_id += self.comm.intra_rank
+            self.device_id = devices[self.comm.rank]
 
         device = self.init_device()
         ranks = f"{self.comm.rank} | {self.comm.intra_rank} | {self.comm.inter_rank}"
@@ -45,7 +46,7 @@ class MPIFinetuner(DefaultFinetuner):
 
 
     def scatter_datasets(self):
-        if self.mpi:
+        if self.check_mpi():
            self.train_data = scatter(self.train_data, self.comm)
            self.val_data = scatter(self.val_data, self.comm)
         else:
@@ -61,18 +62,22 @@ class MPIFinetuner(DefaultFinetuner):
 
         self.scatter_datasets()
 
-    def init_optimizer(self, opts):
-        super(MPIFinetuner, self).init_optimizer(opts)
+    def init_optimizer(self):
+        super(MPIFinetuner, self).init_optimizer()
+
+        if not self.check_mpi():
+            return
 
-        if self.mpi:
-            self.opt = chainermn.create_multi_node_optimizer(self.opt, self.comm)
+        self.opt = chainermn.create_multi_node_optimizer(self.opt, self.comm)
 
     def init_evaluator(self):
         super(MPIFinetuner, self).init_evaluator()
 
-        if self.mpi:
-            self.evaluator = chainermn.create_multi_node_evaluator(
-            self.evaluator, self.comm)
+        if not self.check_mpi():
+            return
+
+        self.evaluator = chainermn.create_multi_node_evaluator(
+            self.evaluator, self.comm)
 
     def run(self, trainer_cls, opts, *args, **kwargs):
         if not self.mpi_main_process:
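
A standalone sketch of the device-selection rule the patched gpu_config applies. FakeComm and pick_device are illustrative stand-ins, not part of the patch; a real ChainerMN communicator exposes rank and intra_rank with the same meaning.

    from dataclasses import dataclass

    @dataclass
    class FakeComm:
        rank: int        # global rank across all MPI processes
        intra_rank: int  # rank among the processes on the same node

    def pick_device(comm, devices):
        if len(devices) == 1:
            # A single entry acts as a base GPU id: processes on the same
            # node fan out over consecutive devices starting from it.
            return devices[0] + comm.intra_rank
        # An explicit list of ids is indexed by the global rank.
        return devices[comm.rank]

    # Four processes on one node, launched with a single device id:
    assert pick_device(FakeComm(rank=3, intra_rank=3), [0]) == 3
    # An explicit per-rank device list:
    assert pick_device(FakeComm(rank=2, intra_rank=0), [4, 5, 6]) == 6

The branch swap is the substantive fix here: the old else branch did self.device_id += self.comm.intra_rank, relying on device_id having been initialized elsewhere, while the new code derives the id from devices[0] and the intra-node rank directly.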