mpi.py 1.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354
  1. import chainermn
  2. from chainermn import scatter_dataset as scatter
  3. from .base import DefaultFinetuner
  4. class MPIFinetuner(DefaultFinetuner):
  5. @property
  6. def mpi(self):
  7. return self.comm is not None
  8. @property
  9. def mpi_main_process(self):
  10. return not self.mpi or self.comm.rank == 0
  11. def gpu_config(self, opts, comm=None):
  12. super(MPIFinetuner, self).gpu_config(opts)
  13. self.comm = comm
  14. if self.mpi:
  15. self.device = opts.gpu[self.comm.rank]
  16. # self.device += self.comm.intra_rank
  17. def scatter_datasets(self):
  18. if self.mpi:
  19. self.train_data = scatter(self.train_data, self.comm)
  20. self.val_data = scatter(self.val_data, self.comm)
  21. def init_datasets(self, *args, **kwargs):
  22. if not self.mpi_main_process:
  23. self.train_data, self.val_data = None, None
  24. return
  25. super(MPIFinetuner, self).init_datasets(*args, **kwargs)
  26. self.scatter_datasets()
  27. def init_optimizer(self, opts):
  28. super(MPIFinetuner, self).init_optimizer(opts)
  29. if self.mpi:
  30. self.opt = chainermn.create_multi_node_optimizer(self.opt, self.comm)
  31. def init_evaluator(self):
  32. super(MPIFinetuner, self).init_evaluator()
  33. if self.mpi:
  34. self.evaluator = chainermn.create_multi_node_evaluator(
  35. self.evaluator, self.comm)
  36. def run(self, opts, ex):
  37. super(MPIFinetuner, self).run(opts, ex, no_observe=not self.mpi_main_process)