@@ -19,7 +19,7 @@ class MPIFinetuner(DefaultFinetuner):
         return not (self.comm is not None and self.comm.rank != 0)

     def gpu_config(self, opts):
-        super(MPIFinetuner, self).gpu_config(opts)
+        device = super(MPIFinetuner, self).gpu_config(opts)

         if self.mpi:
             if len(opts.gpu) > 1:
@@ -27,12 +27,13 @@ class MPIFinetuner(DefaultFinetuner):
             else:
                 self.device_id += self.comm.intra_rank

+            device = self.init_device()
             ranks = f"{self.comm.rank}|{self.comm.intra_rank}|{self.comm.inter_rank}"
-            logging.info(f"Node with ranks {ranks} assigned to GPU #{self.device}")
+            logging.info(f"Node with ranks {ranks} assigned to {device}")
         else:
             logging.warn("Using MPIFinetuner without setting a communicator!")

-        return self.init_device()
+        return device


     def scatter_datasets(self):
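Taken together, the two hunks reshape gpu_config so the device chosen by the base class is threaded through the method and returned, instead of calling self.init_device() a second time on the way out. A rough sketch of the patched method as a whole, for orientation only: the body of the len(opts.gpu) > 1 branch is elided in the diff, and the comments about DefaultFinetuner.gpu_config and init_device describe inferences from how their results are used here, not confirmed behavior.

    def gpu_config(self, opts):
        # Assumption: the base class resolves self.device_id from opts.gpu
        # and returns the device it initialized.
        device = super(MPIFinetuner, self).gpu_config(opts)

        if self.mpi:
            if len(opts.gpu) > 1:
                ...  # per-rank choice from the GPU list (not shown in the diff)
            else:
                # A single GPU id was given: offset it by the node-local
                # (intra) rank so each process on a host gets its own card.
                self.device_id += self.comm.intra_rank

            # Re-initialize with the rank-adjusted id and log the assignment.
            device = self.init_device()
            ranks = f"{self.comm.rank}|{self.comm.intra_rank}|{self.comm.inter_rank}"
            logging.info(f"Node with ranks {ranks} assigned to {device}")
        else:
            logging.warn("Using MPIFinetuner without setting a communicator!")

        return device

Besides dropping the redundant init_device() call at the return, the change also makes the log line report the freshly initialized device rather than reading self.device.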