@@ -1,54 +1,53 @@
 import logging
-try:
-    import chainermn
-except Exception as e: #pragma: no cover
-    _CHAINERMN_AVAILABLE = False #pragma: no cover
-else:
-    _CHAINERMN_AVAILABLE = True
+import warnings
 
 from cvfinetune import utils
+from cvfinetune.utils import mpi
 from cvfinetune.finetuner.base import DefaultFinetuner
 from cvfinetune.finetuner.mpi import MPIFinetuner
 
 from cvdatasets.utils import pretty_print_dict
 
-class FinetunerFactory(object):
+class FinetunerFactory:
 
-    @classmethod
-    def new(cls, *,
-            mpi: bool = False,
-            default=DefaultFinetuner,
-            mpi_tuner=MPIFinetuner,
-            **kwargs):
+    @classmethod
+    def new(cls, *args, **kwargs):
+        raise NotImplementedError(f"Use simple instance creation instead of {cls.__name__}.new()!")
 
-        if mpi:
-            assert _CHAINERMN_AVAILABLE, "Distributed training is not possible!"
-            msg1 = "MPI enabled. Creating NCCL communicator!"
-            comm = chainermn.create_communicator("pure_nccl")
-            msg2 = f"Rank: {comm.rank}, IntraRank: {comm.intra_rank}, InterRank: {comm.inter_rank}"
+    def __init__(self, *,
+                 default=DefaultFinetuner,
+                 mpi_tuner=MPIFinetuner,
+                 **kwargs):
+        super().__init__()
 
-            utils.log_messages([msg1, msg2])
-            return cls(mpi_tuner, comm=comm, **kwargs)
-        else:
-            return cls(default, **kwargs)
+        if "mpi" in kwargs:
+            kwargs.pop("mpi")
+            warnings.warn("\"mpi\" is no longer supported. MPI checks are performed automatically.", category=DeprecationWarning)
 
-    def __init__(self, tuner_cls, **kwargs):
-        super(FinetunerFactory, self).__init__()
+        self.kwargs = kwargs
+        self.tuner_cls = default
 
-        self.tuner_cls = tuner_cls
-        self.kwargs = kwargs
-        logging.info(f"Using {self.tuner_cls.__name__} with arguments: {pretty_print_dict(self.kwargs)}")
+        if mpi.enabled() and mpi.chainermn_available():
+            comm = mpi.new_comm("pure_nccl")
+            msg1 = "MPI enabled. Creating NCCL communicator!"
+            msg2 = f"Rank: {comm.rank}, IntraRank: {comm.intra_rank}, InterRank: {comm.inter_rank}"
+            utils.log_messages(msg1, msg2)
 
-    def __call__(self, opts, **kwargs):
-        opt_kwargs = self.tuner_cls.extract_kwargs(opts)
-        _kwargs = dict(self.kwargs, **kwargs, **opt_kwargs)
-        return self.tuner_cls(config=opts.__dict__, **_kwargs)
+            self["comm"] = comm
+            self.tuner_cls = mpi_tuner
 
-    def get(self, key, default=None):
-        return self.kwargs.get(key, default)
+        logging.info(f"Using {self.tuner_cls.__name__} with arguments: {pretty_print_dict(self.kwargs)}")
 
-    def __getitem__(self, key):
-        return self.kwargs[key]
+    def __call__(self, opts, **kwargs):
+        opt_kwargs = self.tuner_cls.extract_kwargs(opts)
+        _kwargs = dict(self.kwargs, **kwargs, **opt_kwargs)
+        return self.tuner_cls(config=opts.__dict__, **_kwargs)
 
-    def __setitem__(self, key, value):
-        self.kwargs[key] = value
+    def get(self, key, default=None):
+        return self.kwargs.get(key, default)
+
+    def __getitem__(self, key):
+        return self.kwargs[key]
+
+    def __setitem__(self, key, value):
+        self.kwargs[key] = value
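
The diff assumes three helpers on the new cvfinetune.utils.mpi module: mpi.enabled(), mpi.chainermn_available(), and mpi.new_comm(). That module is not shown here; the following is a hypothetical sketch inferred only from those call sites (the environment-variable check for detecting an mpirun launch is an assumption, while chainermn.create_communicator is the real ChainerMN entry point):

    import os

    def chainermn_available() -> bool:
        # Replaces the removed module-level try/except around "import chainermn".
        try:
            import chainermn  # noqa: F401
            return True
        except ImportError:
            return False

    def enabled() -> bool:
        # Assumption: detect an MPI launch via variables set by mpirun/mpiexec.
        return "OMPI_COMM_WORLD_SIZE" in os.environ or "PMI_SIZE" in os.environ

    def new_comm(name: str = "pure_nccl"):
        # Thin wrapper around ChainerMN's communicator factory.
        import chainermn
        return chainermn.create_communicator(name)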
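Migration note: FinetunerFactory.new() now raises NotImplementedError, and the mpi flag is deprecated; the factory picks MPIFinetuner by itself when an MPI run is detected. A minimal usage sketch under these assumptions (the import path and the batch_size kwarg are illustrative, not taken from the diff):

    from cvfinetune.finetuner import FinetunerFactory  # import path assumed

    # Old, now raises NotImplementedError:
    #   factory = FinetunerFactory.new(mpi=True)
    # Passing mpi= to the constructor only emits a DeprecationWarning and is ignored.

    factory = FinetunerFactory(batch_size=32)  # extra kwargs are stored on the factory
    tuner = factory(opts)  # opts: a parsed argparse-style namespace; builds the tuner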