|
@@ -2,18 +2,13 @@ import chainermn
|
|
|
import logging
|
|
import logging
|
|
|
from chainermn import scatter_dataset as scatter
|
|
from chainermn import scatter_dataset as scatter
|
|
|
|
|
|
|
|
-from .base import DefaultFinetuner
|
|
|
|
|
|
|
+from cvfinetune.finetuner.base import DefaultFinetuner
|
|
|
|
|
|
|
|
-class _mpi_mixin(object):
|
|
|
|
|
- """
|
|
|
|
|
- This mixin is used to remove "comm" argument from
|
|
|
|
|
- argument lists, so that object class gets an empty list
|
|
|
|
|
- """
|
|
|
|
|
|
|
+class MPIFinetuner(DefaultFinetuner):
|
|
|
|
|
|
|
|
- def __init__(self, comm, *args, **kwargs):
|
|
|
|
|
- super(_mpi_mixin, self).__init__(*args, **kwargs)
|
|
|
|
|
-
|
|
|
|
|
-class MPIFinetuner(DefaultFinetuner, _mpi_mixin):
|
|
|
|
|
|
|
+ def __init__(self, opts, *args, comm, **kwargs):
|
|
|
|
|
+ self.comm = comm
|
|
|
|
|
+ super(MPIFinetuner, self).__init__(opts, *args, **kwargs)
|
|
|
|
|
|
|
|
@property
|
|
@property
|
|
|
def mpi(self):
|
|
def mpi(self):
|
|
@@ -23,10 +18,9 @@ class MPIFinetuner(DefaultFinetuner, _mpi_mixin):
|
|
|
def mpi_main_process(self):
|
|
def mpi_main_process(self):
|
|
|
return not (self.comm is not None and self.comm.rank != 0)
|
|
return not (self.comm is not None and self.comm.rank != 0)
|
|
|
|
|
|
|
|
- def gpu_config(self, opts, comm=None, *args, **kwargs):
|
|
|
|
|
- super(MPIFinetuner, self).gpu_config(opts, *args, **kwargs)
|
|
|
|
|
|
|
+ def gpu_config(self, opts):
|
|
|
|
|
+ super(MPIFinetuner, self).gpu_config(opts)
|
|
|
|
|
|
|
|
- self.comm = comm
|
|
|
|
|
if self.mpi:
|
|
if self.mpi:
|
|
|
if len(opts.gpu) > 1:
|
|
if len(opts.gpu) > 1:
|
|
|
self.device = opts.gpu[self.comm.rank]
|
|
self.device = opts.gpu[self.comm.rank]
|