@@ -23,14 +23,18 @@ class MPIFinetuner(DefaultFinetuner):
 
         if self.mpi:
             if len(opts.gpu) > 1:
-                self.device = opts.gpu[self.comm.rank]
+                self.device_id = opts.gpu[self.comm.rank]
             else:
-                self.device += self.comm.intra_rank
+                self.device_id += self.comm.intra_rank
+
             ranks = f"{self.comm.rank}|{self.comm.intra_rank}|{self.comm.inter_rank}"
             logging.info(f"Node with ranks {ranks} assigned to GPU #{self.device}")
         else:
             logging.warn("Using MPIFinetuner without setting a communicator!")
 
+        return self.init_device()
+
+
     def scatter_datasets(self):
         if self.mpi:
             self.train_data = scatter(self.train_data, self.comm)
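
The renamed `device_id` is simply the GPU index assigned to each MPI process on its node: taken from `opts.gpu` when several GPUs are listed, otherwise derived from the communicator's `intra_rank`. As an illustration only (not part of the patch), below is a minimal standalone mpi4py sketch of that intra-node rank-to-GPU mapping; the names `world`, `intra`, and `base_gpu` are hypothetical, and `MPIFinetuner` itself reads its ranks from `self.comm` rather than performing the split shown here.

```python
# Illustrative sketch only: map each MPI process to a local GPU by its
# intra-node rank, mirroring `self.device_id += self.comm.intra_rank`.
from mpi4py import MPI

world = MPI.COMM_WORLD
# Processes on the same physical node share this communicator, so their
# ranks within it (0, 1, 2, ...) can double as local GPU indices.
intra = world.Split_type(MPI.COMM_TYPE_SHARED)
intra_rank = intra.Get_rank()

base_gpu = 0                        # analogous to the initial self.device_id
device_id = base_gpu + intra_rank   # one GPU per process on each node
print(f"world rank {world.Get_rank()} -> GPU #{device_id}")
```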