|
|
@@ -18,9 +18,10 @@ class MPIFinetuner(DefaultFinetuner):
|
|
|
|
|
|
self.comm = comm
|
|
|
if self.mpi:
|
|
|
- self.device = opts.gpu[self.comm.rank]
|
|
|
-
|
|
|
- # self.device += self.comm.intra_rank
|
|
|
+ if len(opts.gpu) > 1:
|
|
|
+ self.device = opts.gpu[self.comm.rank]
|
|
|
+ else:
|
|
|
+ self.device += self.comm.intra_rank
|
|
|
|
|
|
def scatter_datasets(self):
|
|
|
if self.mpi:
|