@@ -1,5 +1,7 @@
 import chainermn
 import logging
+import warnings
+
 from chainermn import scatter_dataset as scatter
 
 from cvfinetune.finetuner.base import DefaultFinetuner
@@ -19,20 +21,21 @@ class MPIFinetuner(DefaultFinetuner):
         return not (self.comm is not None and self.comm.rank != 0)
 
     def gpu_config(self, opts):
-        device = super(MPIFinetuner, self).gpu_config(opts)
 
-        if self.mpi:
-            if len(opts.gpu) > 1:
-                self.device_id = opts.gpu[self.comm.rank]
-            else:
-                self.device_id += self.comm.intra_rank
-
-            device = self.init_device()
-            ranks = f"{self.comm.rank}|{self.comm.intra_rank}|{self.comm.inter_rank}"
-            logging.info(f"Node with ranks {ranks} assigned to {device}")
+        if not self.mpi:
+            msg = "Using MPIFinetuner without setting a communicator!"
+            warnings.warn(msg)
+            logging.warn(msg)
+            return super(MPIFinetuner, self).gpu_config(opts)
+
+        if len(opts.gpu) > 1:
+            self.device_id = opts.gpu[self.comm.rank]
         else:
-            logging.warn("Using MPIFinetuner without setting a communicator!")
+            self.device_id += self.comm.intra_rank
 
+        device = self.init_device()
+        ranks = f"{self.comm.rank} | {self.comm.intra_rank} | {self.comm.inter_rank}"
+        logging.info(f"Node with ranks {ranks} assigned to {device}")
         return device
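
For context, a minimal sketch of the rank-to-device mapping this change introduces, assuming a ChainerMN communicator; the helper name pick_device_id and the standalone script layout are illustrative, not part of the patch:

import chainermn

def pick_device_id(comm, gpu_ids):
    """Mirror of the device selection used in gpu_config above."""
    if len(gpu_ids) > 1:
        # Explicit list: one GPU id per global MPI rank.
        return gpu_ids[comm.rank]
    # Single id: shift by the intra-node rank so processes sharing a
    # host are assigned to different GPUs.
    return gpu_ids[0] + comm.intra_rank

if __name__ == "__main__":
    # Run under mpiexec; "pure_nccl" is one of ChainerMN's built-in
    # communicator names.
    comm = chainermn.create_communicator("pure_nccl")
    print(comm.rank, pick_device_id(comm, [0]))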