"""MPI-aware finetuner built on ChainerMN (mpi.py)."""
  1. from .base import BaseFinetuner
  2. class MPIFinetuner(BaseFinetuner):
  3. @property
  4. def mpi(self):
  5. return self.comm is not None
  6. @property
  7. def mpi_main_process(self):
  8. return not self.mpi or self.comm.rank == 0
  9. def gpu_config(self, opts, comm=None):
  10. super(MPIFinetuner, self).gpu_config(opts)
  11. self.comm = comm
  12. if self.mpi:
  13. self.device = opts.gpu[self.comm.rank]
  14. # self.device += self.comm.intra_rank
  15. def scatter_datasets(self):
  16. if self.mpi:
  17. from chainermn import scatter_dataset as scatter
  18. self.train_data = scatter(self.train_data, self.comm)
  19. self.val_data = scatter(self.val_data, self.comm)
  20. def init_datasets(self, *args, **kwargs):
  21. if not self.mpi_main_process:
  22. self.train_data, self.val_data = None, None
  23. return
  24. super(MPIFinetuner, self).init_datasets(*args, **kwargs)
  25. self.scatter_datasets()
  26. def init_optimizer(self, opts):
  27. super(MPIFinetuner, self).init_optimizer(opts)
  28. if self.mpi:
  29. import chainermn
  30. self.opt = chainermn.create_multi_node_optimizer(self.opt, self.comm)
  31. def init_evaluator(self):
  32. super(MPIFinetuner, self).init_evaluator()
  33. if self.mpi:
  34. import chainermn
  35. self.evaluator = chainermn.create_multi_node_evaluator(
  36. self.evaluator, self.comm)
  37. def run(self, opts, ex):
  38. super(MPIFinetuner, self).run(opts, ex, no_observe=not self.mpi_main_process)