
fixed MPI problems

Dimitri Korsch 6 years ago
parent
commit
2429ab0d63
4 changed files with 66 additions and 23 deletions
  1. +6 -2    cvfinetune/finetuner/base.py
  2. +28 -10  cvfinetune/finetuner/mpi.py
  3. +27 -1   cvfinetune/parser.py
  4. +5 -10   cvfinetune/training/trainer.py

+ 6 - 2
cvfinetune/finetuner/base.py

@@ -205,10 +205,11 @@ class _DatasetMixin(abc.ABC):
 		if opts.only_head:
 			self.annot.feature_model = opts.model_type
 
-	def init_datasets(self, opts):
-
 		self.dataset_cls.label_shift = opts.label_shift
 
+
+	def init_datasets(self, opts):
+
 		size = self.model.meta.input_size
 
 		self.prepare = partial(PrepareType[opts.prepare_type](self.model),
@@ -281,6 +282,9 @@ class _TrainerMixin(abc.ABC):
 			weights=self.weights,
 			*args, **kwargs
 		)
+
+		logging.info("Snapshotting is {}abled".format("dis" if opts.no_snapshot else "en"))
+
 		def dump(suffix):
 			if opts.only_eval or opts.no_snapshot:
 				return
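The new log line assembles the word from the flag itself; a quick check of both branches in plain Python:

	>>> "Snapshotting is {}abled".format("dis" if True else "en")
	'Snapshotting is disabled'
	>>> "Snapshotting is {}abled".format("dis" if False else "en")
	'Snapshotting is enabled'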

+ 28 - 10
cvfinetune/finetuner/mpi.py

@@ -1,9 +1,19 @@
 import chainermn
+import logging
 from chainermn import scatter_dataset as scatter
 
 from .base import DefaultFinetuner
 
-class MPIFinetuner(DefaultFinetuner):
+class _mpi_mixin(object):
+	"""
+		This mixin consumes the "comm" argument from the argument
+		list, so that object.__init__ receives no extra arguments
+	"""
+
+	def __init__(self, comm, *args, **kwargs):
+		super(_mpi_mixin, self).__init__(*args, **kwargs)
+
+class MPIFinetuner(DefaultFinetuner, _mpi_mixin):
 
 	@property
 	def mpi(self):
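The mixin relies on Python's cooperative multiple inheritance: `_mpi_mixin` sits after `DefaultFinetuner` in the MRO, so it can consume the extra `comm` argument before the constructor chain reaches `object.__init__`, which accepts no arguments. A minimal, self-contained sketch of the pattern (class names here are hypothetical):

	class Base(object):
		def __init__(self, *args, **kwargs):
			# forwards everything it does not consume itself
			super(Base, self).__init__(*args, **kwargs)

	class _swallow_comm(object):
		def __init__(self, comm=None, *args, **kwargs):
			# "comm" is bound and dropped here, so object.__init__() gets no extras
			super(_swallow_comm, self).__init__(*args, **kwargs)

	class Tuner(Base, _swallow_comm):
		pass

	Tuner(comm="some-communicator")  # works; without the mixin this raises TypeError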
@@ -11,10 +21,10 @@ class MPIFinetuner(DefaultFinetuner):
 
 	@property
 	def mpi_main_process(self):
-		return not self.mpi or self.comm.rank == 0
+		return not (self.comm is not None and self.comm.rank != 0)
 
-	def gpu_config(self, opts, comm=None):
-		super(MPIFinetuner, self).gpu_config(opts)
+	def gpu_config(self, opts, comm=None, *args, **kwargs):
+		super(MPIFinetuner, self).gpu_config(opts, *args, **kwargs)
 
 		self.comm = comm
 		if self.mpi:
@@ -22,19 +32,25 @@ class MPIFinetuner(DefaultFinetuner):
 				self.device = opts.gpu[self.comm.rank]
 			else:
 				self.device += self.comm.intra_rank
+			ranks = f"{self.comm.rank}|{self.comm.intra_rank}|{self.comm.inter_rank}"
+			logging.info(f"Node with ranks {ranks} assigned to GPU #{self.device}")
+		else:
+			logging.warning("Using MPIFinetuner without setting a communicator!")
 
 	def scatter_datasets(self):
 		if self.mpi:
 			self.train_data = scatter(self.train_data, self.comm)
 			self.val_data = scatter(self.val_data, self.comm)
+		else:
+			logging.warning("Data scattering was not possible!")
+
 
 	def init_datasets(self, *args, **kwargs):
 
-		if not self.mpi_main_process:
+		if self.mpi_main_process:
+			super(MPIFinetuner, self).init_datasets(*args, **kwargs)
+		else:
 			self.train_data, self.val_data = None, None
-			return
-
-		super(MPIFinetuner, self).init_datasets(*args, **kwargs)
 
 		self.scatter_datasets()
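The reworked `init_datasets` follows the usual load-then-scatter pattern: only the main rank loads the full datasets, every other rank starts out with `None`, and `chainermn.scatter_dataset` then hands each rank its own shard. Condensed (the loader name is a placeholder):

	from chainermn import scatter_dataset

	if comm.rank == 0:
		train_data = load_full_dataset()  # placeholder loader, runs on rank 0 only
	else:
		train_data = None                 # worker ranks wait for their shard
	train_data = scatter_dataset(train_data, comm)  # every rank now holds a shard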
 
@@ -51,5 +67,7 @@ class MPIFinetuner(DefaultFinetuner):
 			self.evaluator = chainermn.create_multi_node_evaluator(
 				self.evaluator, self.comm)
 
-	def run(self, opts, ex):
-		super(MPIFinetuner, self).run(opts, ex, no_observe=not self.mpi_main_process)
+	def run(self, trainer_cls, opts, *args, **kwargs):
+		kwargs["no_observe"] = not self.mpi_main_process
+		opts.no_snapshot = not self.mpi_main_process
+		super(MPIFinetuner, self).run(trainer_cls, opts, *args, **kwargs)
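With the new `run` signature, both switches are derived from the same predicate, so reporting extensions and snapshot files stay confined to the main rank while worker ranks train silently. In effect:

	is_main = self.mpi_main_process      # i.e. comm is None or comm.rank == 0
	kwargs["no_observe"] = not is_main   # workers: no reporters / printouts
	opts.no_snapshot = not is_main       # workers: no snapshot files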

+ 27 - 1
cvfinetune/parser.py

@@ -1,9 +1,11 @@
 import os
+import logging
+import platform
 
 from chainer_addons.training import OptimizerType
 from chainer_addons.models import PrepareType
 
-from cvargparse import Arg, ArgFactory
+from cvargparse import GPUParser, Arg, ArgFactory
 from cvdatasets.utils import read_info_file
 
 DEFAULT_INFO_FILE=os.environ.get("DATA", "/home/korsch/Data/info.yml")
@@ -75,3 +77,27 @@ def default_factory(extra_list=[]):
 		.debug()\
 		.learning_rate(lr=1e-2, lrs=10, lrt=1e-5, lrd=1e-1)\
 		.weight_decay(default=5e-4)
+
+
+class FineTuneParser(GPUParser):
+	def init_logger(self, simple=False):
+		if not self.has_logging: return
+		fmt = '{levelname:s} - [{asctime:s}] {filename:s}:{lineno:d} [{funcName:s}]: {message:s}'
+
+		handler0 = logging.StreamHandler()
+		handler0.addFilter(HostnameFilter())
+		handler0.setFormatter(logging.Formatter("<{hostname:^10s}>: " + fmt, style="{"))
+
+		handler1 = logging.FileHandler(filename=f"{platform.node()}.log", mode="w")
+		handler1.setFormatter(logging.Formatter(fmt, style="{"))
+
+		logger = logging.getLogger()
+		logger.addHandler(handler0)
+		logger.addHandler(handler1)
+		logger.setLevel(getattr(logging, self.args.loglevel.upper(), logging.DEBUG))
+
+class HostnameFilter(logging.Filter):
+
+	def filter(self, record):
+		record.hostname = platform.node()
+		return True
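The filter injects a `hostname` attribute into every record before the formatter runs, which is what makes the `{hostname:^10s}` field in the console format resolve. A self-contained demo of the mechanism (logger name and message are made up):

	import logging
	import platform

	class HostnameFilter(logging.Filter):
		def filter(self, record):
			record.hostname = platform.node()
			return True

	handler = logging.StreamHandler()
	handler.addFilter(HostnameFilter())
	handler.setFormatter(logging.Formatter("<{hostname:^10s}>: {message:s}", style="{"))

	log = logging.getLogger("demo")
	log.addHandler(handler)
	log.warning("hello")  # prints e.g. "<  mynode  >: hello"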

+ 5 - 10
cvfinetune/training/trainer.py

@@ -47,8 +47,11 @@ class Trainer(T):
 		clf = optimizer.target
 		model = clf.model
 
-		outdir = self.output_directory(opts)
-		logging.info("Training outputs are saved under \"{}\"".format(outdir))
+		if no_observe:
+			outdir = opts.output
+		else:
+			outdir = self.output_directory(opts)
+			logging.info("Training outputs are saved under \"{}\"".format(outdir))
 
 		super(Trainer, self).__init__(
 			updater=updater,
@@ -146,14 +149,6 @@ class Trainer(T):
 			],
 		}
 
-		# if opts.triplet_loss:
-		# 	print_values.extend(["main/t_loss", eval_name("main/t_loss")])
-		# 	plot_values.update({
-		# 		"t_loss": [
-		# 			"main/t_loss", eval_name("main/t_loss"),
-		# 		]
-		# 	})
-
 		return print_values, plot_values
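The `no_observe` branch keeps worker ranks from resolving (and possibly creating) their own run directory: only the observing main process derives a run-specific path, everyone else reuses `opts.output` verbatim. The body of `output_directory` is not part of this diff; a plausible sketch, purely as an assumption for illustration:

	import os
	from datetime import datetime

	def output_directory(self, opts):
		# assumed behaviour: one timestamped subfolder per training run
		outdir = os.path.join(opts.output, datetime.now().strftime("%Y-%m-%d_%H.%M.%S"))
		os.makedirs(outdir, exist_ok=True)
		return outdir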