Explorar el Código

reworked mpi checks and added in the utils some mpi functions

Dimitri Korsch hace 3 años
padre
commit
08a70353f9

+ 36 - 37
cvfinetune/finetuner/factory.py

@@ -1,54 +1,53 @@
 import logging
-try:
-    import chainermn
-except Exception as e: #pragma: no cover
-    _CHAINERMN_AVAILABLE = False #pragma: no cover
-else:
-    _CHAINERMN_AVAILABLE = True
+import warnings
 
 from cvfinetune import utils
+from cvfinetune.utils import mpi
 from cvfinetune.finetuner.base import DefaultFinetuner
 from cvfinetune.finetuner.mpi import MPIFinetuner
 
 from cvdatasets.utils import pretty_print_dict
 
-class FinetunerFactory(object):
class FinetunerFactory:
	"""Factory that selects and instantiates the proper finetuner class.

	On construction it checks whether the process runs under MPI and
	whether chainermn is importable; if so, an NCCL communicator is
	created and the MPI-aware finetuner class is selected, otherwise the
	default finetuner is used.  The stored keyword arguments are passed
	on to the finetuner class when the factory instance is called.
	"""

	@classmethod
	def new(cls, *args, **kwargs):
		"""Deprecated entry point: create an instance directly instead."""
		# NOTE: f-prefix was missing here, so the message used to contain
		# the literal text "{cls.__name__}" instead of the class name.
		raise NotImplementedError(
			f"Use simple instance creation instead of {cls.__name__}.new()!")

	def __init__(self, *,
	             default=DefaultFinetuner,
	             mpi_tuner=MPIFinetuner,
	             **kwargs):
		"""
		Args:
			default: finetuner class used for single-process training.
			mpi_tuner: finetuner class used for distributed (MPI) training.
			**kwargs: keyword arguments stored and later passed on to the
				selected finetuner class.
		"""
		super().__init__()

		# The explicit "mpi" flag is deprecated: the MPI check below is
		# performed automatically on every construction.
		if "mpi" in kwargs:
			kwargs.pop("mpi")
			warnings.warn("\"mpi\" is no longer supported. MPI checks are performed automatically.", category=DeprecationWarning)

		self.kwargs = kwargs
		self.tuner_cls = default

		# Only switch to the MPI tuner when we actually run with more than
		# one rank AND chainermn is importable (the second check asserts
		# otherwise, preserving the old "Distributed training is not
		# possible!" failure mode).
		if mpi.enabled() and mpi.chainermn_available():
			comm = mpi.new_comm("pure_nccl")
			msg1 = "MPI enabled. Creating NCCL communicator!"
			msg2 = f"Rank: {comm.rank}, IntraRank: {comm.intra_rank}, InterRank: {comm.inter_rank}"
			utils.log_messages(msg1, msg2)

			self["comm"] = comm
			self.tuner_cls = mpi_tuner

		logging.info(f"Using {self.tuner_cls.__name__} with arguments: {pretty_print_dict(self.kwargs)}")

	def __call__(self, opts, **kwargs):
		"""Instantiate the selected finetuner class.

		Args:
			opts: parsed options object; tuner-specific arguments are
				extracted from it and its ``__dict__`` becomes the config.
			**kwargs: extra arguments; override the stored ones, but are
				themselves overridden by the extracted option arguments.
		"""
		opt_kwargs = self.tuner_cls.extract_kwargs(opts)
		_kwargs = dict(self.kwargs, **kwargs, **opt_kwargs)
		return self.tuner_cls(config=opts.__dict__, **_kwargs)

	def get(self, key, default=None):
		"""dict-like read access (with default) to the stored arguments."""
		return self.kwargs.get(key, default)

	def __getitem__(self, key):
		"""dict-like read access to the stored arguments."""
		return self.kwargs[key]

	def __setitem__(self, key, value):
		"""dict-like write access to the stored arguments."""
		self.kwargs[key] = value

+ 1 - 17
cvfinetune/utils/__init__.py

@@ -1,17 +1 @@
-import logging
-
-def log_messages(msg_list, *, n_chars=10, char="=", logger=None, level=logging.INFO):
-	"""
-		Adds <chars> and a space at each end of every message.
-		Adjusts the length of each message to the maximum length
-		of the messages in the list.
-	"""
-
-	max_len = max(map(len, msg_list))
-	fmt_len = max_len + 2*(n_chars+1)
-	fmt = "{:" + char + "^" + str(fmt_len) + "s}"
-
-	logger = logger or logging.getLogger()
-
-	for msg in msg_list:
-		logger.log(level, fmt.format(f" {msg} "))
+from cvfinetune.utils.logging import log_messages

+ 17 - 0
cvfinetune/utils/logging.py

@@ -0,0 +1,17 @@
+import logging
+
def log_messages(*msg_list, n_chars=10, char="=", logger=None, level=logging.INFO):
	"""Log each message centered in a banner of *char* characters.

	Every message is wrapped in a single space on each side and centered
	in a line of fill characters.  All banners produced by one call have
	the same width: the length of the longest message plus
	``2 * (n_chars + 1)`` characters.

	Args:
		*msg_list: messages to log; calling with no messages is a no-op.
		n_chars: fill characters on each side of the longest message.
		char: the fill character used for the banner.
		logger: target logger; defaults to the root logger.
		level: logging level used for every message.
	"""
	# Guard: max() below raises ValueError on an empty sequence.
	if not msg_list:
		return

	max_len = max(map(len, msg_list))
	fmt_len = max_len + 2 * (n_chars + 1)
	fmt = "{:" + char + "^" + str(fmt_len) + "s}"

	logger = logger or logging.getLogger()

	for msg in msg_list:
		logger.log(level, fmt.format(f" {msg} "))

+ 22 - 0
cvfinetune/utils/mpi.py

@@ -0,0 +1,22 @@
+import mpi4py.MPI as MPI
+
try:
	# chainermn is an optional dependency.  The broad Exception is
	# deliberate: chainermn can fail during its own initialization with
	# errors other than ImportError, and any failure simply means
	# distributed training is unavailable.
	import chainermn
except Exception: #pragma: no cover
	_CHAINERMN_AVAILABLE = False #pragma: no cover
else:
	_CHAINERMN_AVAILABLE = True


def chainermn_available(strict: bool = True) -> bool:
	"""Return whether chainermn was importable.

	Args:
		strict: when True (the default), raise AssertionError with
			"Distributed training is not possible!" instead of
			returning False.
	"""
	if strict:
		assert _CHAINERMN_AVAILABLE, "Distributed training is not possible!"

	return _CHAINERMN_AVAILABLE
+
def enabled() -> bool:
	"""Return True when the process runs as part of an MPI job with more
	than one rank, i.e. when distributed training is meaningful."""
	world_size = MPI.COMM_WORLD.Get_size()
	return world_size > 1
+
def new_comm(comm_type: str = "pure_nccl"):
	"""Create a new chainermn communicator of the given type.

	Args:
		comm_type: communicator type passed to
			``chainermn.create_communicator``, e.g. "pure_nccl".

	Returns:
		The created chainermn communicator.
	"""
	# Bug fix: comm_type was previously ignored and "pure_nccl" was
	# always hard-coded in the call.
	return chainermn.create_communicator(comm_type)