Просмотр исходного кода

minor changes in argument handling of the MPIFinetuner

Dimitri Korsch 5 лет назад
Родитель
Сommit
d4d2e3f32b
2 измененных файлов с 9 добавлено и 15 удалено
  1. 2 2
      cvfinetune/finetuner/base.py
  2. 7 13
      cvfinetune/finetuner/mpi.py

+ 2 - 2
cvfinetune/finetuner/base.py

@@ -350,7 +350,7 @@ class DefaultFinetuner(_ModelMixin, _DatasetMixin, _TrainerMixin):
 	def __init__(self, opts, *args, **kwargs):
 		super(DefaultFinetuner, self).__init__(*args, **kwargs)
 
-		self.gpu_config(opts, *args, **kwargs)
+		self.gpu_config(opts)
 		cuda.get_device_from_id(self.device).use()
 
 		self.init_annotations(opts)
@@ -366,7 +366,7 @@ class DefaultFinetuner(_ModelMixin, _DatasetMixin, _TrainerMixin):
 		self.init_updater()
 		self.init_evaluator()
 
-	def gpu_config(self, opts, *args, **kwargs):
+	def gpu_config(self, opts):
 		if -1 in opts.gpu:
 			self.device = -1
 		else:

+ 7 - 13
cvfinetune/finetuner/mpi.py

@@ -2,18 +2,13 @@ import chainermn
 import logging
 from chainermn import scatter_dataset as scatter
 
-from .base import DefaultFinetuner
+from cvfinetune.finetuner.base import DefaultFinetuner
 
-class _mpi_mixin(object):
-	"""
-		This mixin is used to remove "comm" argument from
-		argument lists, so that object class gets an empty list
-	"""
+class MPIFinetuner(DefaultFinetuner):
 
-	def __init__(self, comm, *args, **kwargs):
-		super(_mpi_mixin, self).__init__(*args, **kwargs)
-
-class MPIFinetuner(DefaultFinetuner, _mpi_mixin):
+	def __init__(self, opts, *args, comm, **kwargs):
+		self.comm = comm
+		super(MPIFinetuner, self).__init__(opts, *args, **kwargs)
 
 	@property
 	def mpi(self):
@@ -23,10 +18,9 @@ class MPIFinetuner(DefaultFinetuner, _mpi_mixin):
 	def mpi_main_process(self):
 		return not (self.comm is not None and self.comm.rank != 0)
 
-	def gpu_config(self, opts, comm=None, *args, **kwargs):
-		super(MPIFinetuner, self).gpu_config(opts, *args, **kwargs)
+	def gpu_config(self, opts):
+		super(MPIFinetuner, self).gpu_config(opts)
 
-		self.comm = comm
 		if self.mpi:
 			if len(opts.gpu) > 1:
 				self.device = opts.gpu[self.comm.rank]