
refactored finetuner class mixins

Dimitri Korsch 3 years ago
commit 1e900ed8e6

+ 13 - 35
cvfinetune/dataset.py

@@ -1,49 +1,27 @@
-import numpy as np
 import abc
-
-from chainer_addons.dataset import AugmentationMixin
-from chainer_addons.dataset import PreprocessMixin
+import numpy as np
 
 
 from cvdatasets.dataset import AnnotationsReadMixin
-from cvdatasets.dataset import RevealedPartMixin
 from cvdatasets.dataset import IteratorMixin
+from cvdatasets.dataset import TransformMixin
+from cvdatasets.dataset import UniformPartMixin
 
 
-class _pre_augmentation_mixin(abc.ABC):
-	""" This mixin discards the parts from the ImageWrapper object
-	and shifts the labels
-	"""
-
-	label_shift = 1
-
-	def get_example(self, i):
-		im_obj = super(_pre_augmentation_mixin, self).get_example(i)
-		im, parts, lab = im_obj.as_tuple()
-		return im, lab + self.label_shift
-
-class _base_mixin(abc.ABC):
-	""" This mixin converts images,that are in range
-	[0..1] to the range [-1..1]
-	"""
+class BaseDataset(TransformMixin, UniformPartMixin, AnnotationsReadMixin):
+	"""Commonly used dataset constellation"""
 
 
-	def get_example(self, i):
-		im, lab = super(_base_mixin, self).get_example(i)
+	def __init__(self, *args, prepare, center_crop_on_val: bool = True, **kwargs):
+		super().__init__(*args, **kwargs)
+		self.prepare = prepare
 
 
+	def augment(self, im):
 		if isinstance(im, list):
 			im = np.array(im)
 
 
 		if np.logical_and(0 <= im, im <= 1).all():
 			im = im * 2 -1
 
 
-		return im, lab
-
+		return im
 
 
-class BaseDataset(_base_mixin,
-	# augmentation and preprocessing
-	AugmentationMixin, PreprocessMixin,
-	_pre_augmentation_mixin,
-	# random uniform region selection
-	RevealedPartMixin,
-	# reads image
-	AnnotationsReadMixin,
-	IteratorMixin):
-	"""Commonly used dataset constellation"""
+	def transform(self, im_obj):
+		im, parts, lab = im_obj.as_tuple()
+		return self.prepare(im), lab + self.label_shift
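
Note on the augment step above: inputs arriving in the [0..1] range are rescaled to [-1..1], everything else is passed through unchanged. A minimal, self-contained sketch of that value-range handling (plain numpy, independent of the mixin chain; the helper name is illustrative):

import numpy as np

def rescale(im):
    if isinstance(im, list):
        im = np.array(im)
    if np.logical_and(0 <= im, im <= 1).all():
        im = im * 2 - 1
    return im

print(rescale([0.0, 0.5, 1.0]))          # -> [-1.  0.  1.]
print(rescale(np.array([0, 128, 255])))  # values outside [0, 1] stay untouched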

+ 27 - 16
cvfinetune/finetuner/base.py

@@ -3,41 +3,52 @@ import logging
 
 
 from cvfinetune.finetuner import mixins
 
 
-class DefaultFinetuner(mixins._ModelMixin, mixins._DatasetMixin, mixins._TrainerMixin):
+class DefaultFinetuner(
+	mixins._ModelMixin,
+	mixins._OptimizerMixin,
+	mixins._ClassifierMixin,
+	mixins._DatasetMixin,
+	mixins._IteratorMixin,
+	mixins._TrainerMixin):
 	""" The default Finetuner gathers together the creations of all needed
 	""" The default Finetuner gathers together the creations of all needed
 	components and call them in the correct order
 	components and call them in the correct order
 
 
 	"""
 	"""
 
 
-	def __init__(self, opts, *args, **kwargs):
-		super(DefaultFinetuner, self).__init__(opts=opts, *args, **kwargs)
+	def __init__(self, *args, gpu = [-1], **kwargs):
+		super().__init__(*args, **kwargs)
 
 
-		self.gpu_config(opts)
-		self.read_annotations(opts)
+		self.gpu_config(gpu)
+		self.read_annotations()
 
 
-		self.init_model(opts)
-		self.init_datasets(opts)
-		self.init_iterators(opts)
+		self.init_model()
+		self.init_datasets()
+		self.init_iterators()
 
 
-		self.init_classifier(opts)
-		self.load_weights(opts)
+		self.init_classifier()
+		self.load_weights()
 
 
-		self.init_optimizer(opts)
+		self.init_optimizer()
 		self.init_updater()
 		self.init_evaluator()
 
 
+
+	def _check_attr(self, attr_name, msg=None):
+		msg = msg or f"<{type(self).__name__}> {attr_name} attribute was not initialized!"
+		assert hasattr(self, attr_name), msg
+
 	def init_device(self):
 		self.device = chainer.get_device(self.device_id)
 		self.device.use()
 		return self.device
 
 
-
-	def gpu_config(self, opts):
-		if -1 in opts.gpu:
+	def gpu_config(self, devices):
+		if -1 in devices:
 			self.device_id = -1
 		else:
-			self.device_id = opts.gpu[0]
+			self.device_id = devices[0]
 
 
+		device = self.init_device()
 		logging.info(f"Using device {device}")
 		logging.info(f"Using device {device}")
-		return self.init_device()
+		return device
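
The device handling is unchanged in substance; only its input moved from opts.gpu to an explicit gpu list. A tiny runnable sketch of the selection rule (the helper name below is illustrative, not part of the commit):

def pick_device_id(devices):
    # -1 anywhere in the list forces CPU, otherwise the first entry is used;
    # MPIFinetuner later overrides this choice per MPI rank.
    return -1 if -1 in devices else devices[0]

assert pick_device_id([-1]) == -1
assert pick_device_id([0, 1]) == 0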
 
 

+ 33 - 29
cvfinetune/finetuner/factory.py

@@ -1,10 +1,10 @@
 import logging
 try:
-	import chainermn
+    import chainermn
 except Exception as e: #pragma: no cover
-	_CHAINERMN_AVAILABLE = False #pragma: no cover
+    _CHAINERMN_AVAILABLE = False #pragma: no cover
 else:
-	_CHAINERMN_AVAILABLE = True
+    _CHAINERMN_AVAILABLE = True
 
 
 from cvfinetune import utils
 from cvfinetune.finetuner.base import DefaultFinetuner
@@ -14,37 +14,41 @@ from cvdatasets.utils import pretty_print_dict
 
 
 class FinetunerFactory(object):
 
 
-	@classmethod
-	def new(cls, opts, default=DefaultFinetuner, mpi_tuner=MPIFinetuner):
+    @classmethod
+    def new(cls, *,
+            mpi: bool = False,
+            default=DefaultFinetuner,
+            mpi_tuner=MPIFinetuner,
+            **kwargs):
 
 
-		if getattr(opts, "mpi", False):
-			assert _CHAINERMN_AVAILABLE, "Distributed training is not possible!"
-			msg1 = "MPI enabled. Creating NCCL communicator!"
-			comm = chainermn.create_communicator("pure_nccl")
-			msg2 = f"Rank: {comm.rank}, IntraRank: {comm.intra_rank}, InterRank: {comm.inter_rank}"
+        if mpi:
+            assert _CHAINERMN_AVAILABLE, "Distributed training is not possible!"
+            msg1 = "MPI enabled. Creating NCCL communicator!"
+            comm = chainermn.create_communicator("pure_nccl")
+            msg2 = f"Rank: {comm.rank}, IntraRank: {comm.intra_rank}, InterRank: {comm.inter_rank}"
 
 
-			utils.log_messages([msg1, msg2])
-			return cls(mpi_tuner, comm=comm)
-		else:
-			return cls(default)
+            utils.log_messages([msg1, msg2])
+            return cls(mpi_tuner, comm=comm, **kwargs)
+        else:
+            return cls(default, **kwargs)
 
 
-	def __init__(self, tuner_cls, **kwargs):
-		super(FinetunerFactory, self).__init__()
+    def __init__(self, tuner_cls, **kwargs):
+        super(FinetunerFactory, self).__init__()
 
 
-		self.tuner_cls = tuner_cls
-		self.kwargs = kwargs
-		logging.info(f"Using {self.tuner_cls.__name__} with arguments: {pretty_print_dict(self.kwargs)}")
+        self.tuner_cls = tuner_cls
+        self.kwargs = kwargs
+        logging.info(f"Using {self.tuner_cls.__name__} with arguments: {pretty_print_dict(self.kwargs)}")
 
 
-	def __call__(self, **kwargs):
-		_kwargs = dict(self.kwargs, **kwargs)
+    def __call__(self, opts, **kwargs):
+        opt_kwargs = self.tuner_cls.extract_kwargs(opts)
+        _kwargs = dict(self.kwargs, **kwargs, **opt_kwargs)
+        return self.tuner_cls(**_kwargs)
 
 
-		return self.tuner_cls(**_kwargs)
+    def get(self, key, default=None):
+        return self.kwargs.get(key, default)
 
 
-	def get(self, key, default=None):
-		return self.kwargs.get(key, default)
+    def __getitem__(self, key):
+        return self.kwargs[key]
 
 
-	def __getitem__(self, key):
-		return self.kwargs[key]
-
-	def __setitem__(self, key, value):
-		self.kwargs[key] = value
+    def __setitem__(self, key, value):
+        self.kwargs[key] = value
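
A hedged usage sketch of the new factory flow with a stand-in tuner class, assuming FinetunerFactory and BaseMixin are importable from the module paths used in this commit: the factory keeps construction kwargs, and each call extracts the matching keyword-only arguments from an argparse-style opts object, with the opts values taking precedence over the stored kwargs.

import argparse

from cvfinetune.finetuner.factory import FinetunerFactory
from cvfinetune.finetuner.mixins.base import BaseMixin

class ToyTuner(BaseMixin):
    def __init__(self, *args, batch_size: int = 32, learning_rate: float = 1e-3, **kwargs):
        super().__init__(*args, **kwargs)
        self.batch_size = batch_size
        self.learning_rate = learning_rate

factory = FinetunerFactory.new(mpi=False, default=ToyTuner, batch_size=64)
opts = argparse.Namespace(learning_rate=5e-4)

tuner = factory(opts)  # extract_kwargs(opts) contributes {"learning_rate": 5e-4}
assert (tuner.batch_size, tuner.learning_rate) == (64, 5e-4)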

+ 6 - 0
cvfinetune/finetuner/mixins/__init__.py

@@ -1,10 +1,16 @@
 from cvfinetune.finetuner.mixins.dataset import _DatasetMixin
+from cvfinetune.finetuner.mixins.classifier import _ClassifierMixin
 from cvfinetune.finetuner.mixins.model import _ModelMixin
+from cvfinetune.finetuner.mixins.optimizer import _OptimizerMixin
+from cvfinetune.finetuner.mixins.iterator import _IteratorMixin
 from cvfinetune.finetuner.mixins.trainer import _TrainerMixin
 
 
 
 
 __all__ = [
 	"_DatasetMixin",
+	"_ClassifierMixin",
 	"_ModelMixin",
+	"_OptimizerMixin",
+	"_IteratorMixin",
 	"_TrainerMixin",
 	"_TrainerMixin",
 ]
 ]

+ 71 - 0
cvfinetune/finetuner/mixins/base.py

@@ -0,0 +1,71 @@
+import abc
+import inspect
+
+
+class BaseMixin(abc.ABC):
+
+	def _after_init_check(self):
+		pass
+
+	@classmethod
+	def extract_kwargs(cls, opts) -> dict:
+
+		kwargs = {}
+
+		for klass in cls.mro():
+			sig = inspect.signature(klass.__init__)
+			for attr, param in sig.parameters.items():
+				if param.kind is not inspect.Parameter.KEYWORD_ONLY:
+					continue
+
+				if param.name in kwargs:
+					continue
+
+				if hasattr(opts, param.name):
+					value = getattr(opts, param.name)
+					kwargs[param.name] = value
+		return kwargs
+
+
+if __name__ == '__main__':
+
+	from collections import namedtuple
+	class Foo(BaseMixin):
+
+		@classmethod
+		def extract_kwargs(cls, opts) -> dict:
+			return super().extract_kwargs(opts)
+
+		def __init__(self, *args, foo, bar=0, **kwargs):
+			super().__init__(*args, **kwargs)
+			self.foo = foo
+			self.bar = bar
+
+
+	class Bar(BaseMixin):
+		@classmethod
+		def extract_kwargs(cls, opts) -> dict:
+			return super().extract_kwargs(opts)
+
+		def __init__(self, *args, bar2=-1, **kwargs):
+			super().__init__(*args, **kwargs)
+			self.bar2 = bar2
+
+
+	class Final(Bar, Foo):
+
+
+		def __init__(self, *args, beef=-1, **kwargs):
+			super().__init__(*args, **kwargs)
+			self.beef = beef
+
+		def __repr__(self):
+			return str(self.__dict__)
+
+
+	Opts = namedtuple("Opts", "foo foo2 bar bar2 beef1")
+
+	opts = Opts(1,2,3, -4, "hat")
+	kwargs = Final.extract_kwargs(opts)
+
+	print(opts, Final(**kwargs))
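
Continuing the toy example above: in practice opts is an argparse.Namespace (that is what FinetunerFactory.__call__ receives), and extract_kwargs only relies on getattr-style access, so the same extraction works unchanged with the Foo class defined in the __main__ block:

import argparse

ns = argparse.Namespace(foo=1, bar=2, unrelated="ignored")  # `unrelated` matches no parameter
assert Foo.extract_kwargs(ns) == dict(foo=1, bar=2)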

+ 67 - 0
cvfinetune/finetuner/mixins/classifier.py

@@ -0,0 +1,67 @@
+import abc
+import logging
+
+from chainer import functions as F
+from chainer_addons.functions import smoothed_cross_entropy
+from cvdatasets.utils import pretty_print_dict
+from functools import partial
+
+from cvfinetune.finetuner.mixins.base import BaseMixin
+
+class _ClassifierCreator:
+
+    def __init__(self, cls, **kwargs):
+        super().__init__()
+        self.cls = cls
+        self.kwargs = kwargs
+
+    def __call__(self, *args, **kwargs):
+        kwargs = dict(self.kwargs, **kwargs)
+        return self.cls(*args, **kwargs)
+
+class _ClassifierMixin(BaseMixin):
+    """
+        This mixin wraps the backbone model in a classifier instance.
+    """
+
+    def __init__(self, *args,
+                 classifier_cls,
+                 classifier_kwargs: dict = {},
+                 l1_loss: bool = False,
+                 label_smoothing: float = 0.0,
+                 **kwargs):
+
+        super().__init__(*args, **kwargs)
+        self._clf_creator = _ClassifierCreator(classifier_cls, **classifier_kwargs)
+
+        self._l1_loss = l1_loss
+        self._label_smoothing = label_smoothing
+
+
+    def init_classifier(self):
+        self._check_attr("model")
+        self._check_attr("n_classes")
+
+        self.clf = self._clf_creator(model=self.model,
+                                     loss_func=self.loss_func)
+
+        kwargs = self._clf_creator.kwargs
+        logging.info(
+            f"Wrapped the model around {type(self.clf).__name__}"
+            f" with kwargs: {pretty_print_dict(kwargs)}"
+        )
+
+    @property
+    def loss_func(self):
+        if self._l1_loss:
+            return F.hinge
+
+        if self._label_smoothing > 0:
+            assert self._label_smoothing < 1, "Label smoothing factor must be less than 1!"
+
+            return partial(smoothed_cross_entropy,
+                           N=self.n_classes,
+                           eps=self._label_smoothing)
+
+        return F.softmax_cross_entropy
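
The _ClassifierCreator introduced here behaves like functools.partial for a class: kwargs given at creation time act as defaults that call-time kwargs override. A small runnable sketch using plain dict as the wrapped class (kwarg names are illustrative; it assumes the package and its chainer dependencies are importable):

from cvfinetune.finetuner.mixins.classifier import _ClassifierCreator

creator = _ClassifierCreator(dict, dropout=0.5, n_layers=2)
assert creator(n_layers=3) == dict(dropout=0.5, n_layers=3)  # call-time kwargs win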

+ 73 - 68
cvfinetune/finetuner/mixins/dataset.py

@@ -1,96 +1,101 @@
 import abc
 import logging
+import typing as T
 
 
+from collections import namedtuple
 from cvdatasets import AnnotationType
 from cvdatasets.dataset.image import Size
-from cvdatasets.utils import new_iterator
 
 
+from cvfinetune.finetuner.mixins.base import BaseMixin
 
 
-class _DatasetMixin(abc.ABC):
-	"""
-		This mixin is responsible for annotation loading and for
-		dataset and iterator creation.
-	"""
 
 
-	def __init__(self, opts, dataset_cls, dataset_kwargs_factory, *args, **kwargs):
-		super(_DatasetMixin, self).__init__(opts=opts, *args, **kwargs)
-		self.annot = None
-		self.dataset_type = opts.dataset
-		self.dataset_cls = dataset_cls
-		self.dataset_kwargs_factory = dataset_kwargs_factory
+class _DatasetMixin(BaseMixin):
+    """
+        This mixin is responsible for annotation loading and
+        dataset creation.
+    """
 
 
-	@property
-	def n_classes(self):
-		return self.ds_info.n_classes + self.dataset_cls.label_shift
+    def __init__(self,
+                 *args,
+                 data: str,
+                 dataset: str,
+                 dataset_cls: T.Type,
+                 dataset_kwargs_factory: T.Optional[T.Callable] = None,
 
 
-	@property
-	def data_info(self):
-		assert self.annot is not None, "annot attribute was not set!"
-		return self.annot.info
+                 label_shift: int = 0,
+                 input_size: int = 224,
+                 part_input_size:  T.Optional[int] = None,
+                 **kwargs):
 
 
-	@property
-	def ds_info(self):
-		return self.data_info.DATASETS[self.dataset_type]
+        super().__init__(*args, **kwargs)
+        self.annot = None
+        self.info_file = data
+        self.dataset_name = dataset
+        self.dataset_cls = dataset_cls
+        self.dataset_kwargs_factory = dataset_kwargs_factory
 
 
-	def new_dataset(self, opts, size, part_size, subset):
-		"""Creates a dataset for a specific subset and certain options"""
-		if self.dataset_kwargs_factory is not None and callable(self.dataset_kwargs_factory):
-			kwargs = self.dataset_kwargs_factory(opts, subset)
-		else:
-			kwargs = dict()
+        self.input_size = Size(input_size)
 
 
-		kwargs = dict(kwargs,
-			subset=subset,
-			dataset_cls=self.dataset_cls,
-			prepare=self.prepare,
-			size=size,
-			part_size=part_size,
-			center_crop_on_val=getattr(opts, "center_crop_on_val", False),
-		)
+        if part_input_size is None:
+            self.part_input_size = self.input_size
 
 
+        else:
+            self.part_input_size = Size(part_input_size)
 
 
-		ds = self.annot.new_dataset(**kwargs)
-		logging.info("Loaded {} images".format(len(ds)))
-		return ds
+        self._label_shift = label_shift
 
 
 
 
-	def read_annotations(self, opts):
-		"""Reads annotations and creates annotation instance, which holds important infos about the dataset"""
+    def read_annotations(self):
+        """Reads annotations and creates annotation instance, which holds important infos about the dataset"""
+        opts = namedtuple("Opt", "data dataset")(self.info_file, self.dataset_name)
+        self.annot = AnnotationType.new_annotation(opts, load_strict=False)
+        self.dataset_cls.label_shift = self._label_shift
 
 
-		self.annot = AnnotationType.new_annotation(opts, load_strict=False)
-		self.dataset_cls.label_shift = opts.label_shift
+    def init_datasets(self):
+        self._check_attr("prepare")
+        self._check_attr("_center_crop_on_val")
 
 
+        logging.info(" | ".join([
+            f"Image input size: {self.input_size}",
+            f"Parts input size: {self.part_input_size}",
+        ]))
 
 
-	def init_datasets(self, opts):
+        self.train_data = self.new_dataset("train")
+        self.val_data = self.new_dataset("test")
 
 
-		size = Size(opts.input_size)
-		part_size = getattr(opts, "parts_input_size", None)
-		part_size = size if part_size is None else Size(part_size)
 
 
-		logging.info(" ".join([
-			f"Image input size: {size}",
-			f"Image parts input size: {part_size}",
-		]))
+    @property
+    def n_classes(self):
+        return self.ds_info.n_classes + self._label_shift
 
 
-		self.train_data = self.new_dataset(opts, size, part_size, "train")
-		self.val_data = self.new_dataset(opts, size, part_size, "test")
+    @property
+    def data_info(self):
+        assert self.annot is not None, "annot attribute was not set!"
+        return self.annot.info
 
 
-	def init_iterators(self, opts):
-		"""Creates training and validation iterators from training and validation datasets"""
+    @property
+    def ds_info(self):
+        return self.data_info.DATASETS[self.dataset_name]
 
 
-		kwargs = dict(n_jobs=opts.n_jobs, batch_size=opts.batch_size)
+    def new_dataset(self, subset: str):
+        """Creates a dataset for a specific subset and certain options"""
+        if self.dataset_kwargs_factory is not None and callable(self.dataset_kwargs_factory):
+            kwargs = self.dataset_kwargs_factory(subset)
+        else:
+            kwargs = dict()
 
 
-		if hasattr(self.train_data, "new_iterator"):
-			self.train_iter, _ = self.train_data.new_iterator(**kwargs)
-		else:
-			self.train_iter, _ = new_iterator(self.train_data, **kwargs)
+        kwargs = dict(kwargs,
+            subset=subset,
+            dataset_cls=self.dataset_cls,
+            prepare=self.prepare,
+            size=self.input_size,
+            part_size=self.part_input_size,
+            center_crop_on_val=self._center_crop_on_val,
+        )
+
+
+        ds = self.annot.new_dataset(**kwargs)
+        logging.info(f"Loaded {len(ds)} images")
+        return ds
 
 
-		if hasattr(self.val_data, "new_iterator"):
-			self.val_iter, _ = self.val_data.new_iterator(**kwargs,
-				repeat=False, shuffle=False
-			)
-		else:
-			self.val_iter, _ = new_iterator(self.val_data,
-				**kwargs, repeat=False, shuffle=False
-			)
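
Design note on read_annotations above: the finetuner side dropped the opts object, but cvdatasets' AnnotationType.new_annotation still expects one, so the mixin rebuilds a minimal opts-like shim from the stored values. A standalone illustration of that shim (path and dataset name are made up):

from collections import namedtuple

Opt = namedtuple("Opt", "data dataset")
opts = Opt("path/to/info.yml", "CUB200")
# Only these two fields are handed to AnnotationType.new_annotation(opts, load_strict=False).
assert (opts.data, opts.dataset) == ("path/to/info.yml", "CUB200")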
 
 

+ 41 - 0
cvfinetune/finetuner/mixins/iterator.py

@@ -0,0 +1,41 @@
+import abc
+import logging
+import typing as T
+
+from cvdatasets.utils import new_iterator
+
+from cvfinetune.finetuner.mixins.base import BaseMixin
+
+class _IteratorMixin(BaseMixin):
+
+    def __init__(self,
+                 *args,
+                 batch_size: int = 32,
+                 n_jobs: int = 1,
+                 **kwargs):
+        super().__init__(*args, **kwargs)
+
+        self._batch_size = batch_size
+        self._n_jobs = n_jobs
+
+
+    def new_iterator(self, ds, **kwargs):
+        if hasattr(ds, "new_iterator"):
+            return ds.new_iterator(**kwargs)
+        else:
+            return new_iterator(ds, **kwargs)
+
+    def init_iterators(self):
+        """Creates training and validation iterators from training and validation datasets"""
+
+        self._check_attr("val_data")
+        self._check_attr("train_data")
+
+        kwargs = dict(n_jobs=self._n_jobs, batch_size=self._batch_size)
+
+        self.train_iter, _ = self.new_iterator(self.train_data,
+                                               **kwargs)
+
+        self.val_iter, _ = self.new_iterator(self.val_data,
+                                             repeat=False, shuffle=False,
+                                             **kwargs)
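
new_iterator above prefers a dataset's own iterator factory and only falls back to cvdatasets.utils.new_iterator otherwise. A plain-Python sketch of that dispatch with stand-in classes (no cvdatasets required); the two-value return mirrors what init_iterators unpacks:

def make_iterator(ds, fallback, **kwargs):
    if hasattr(ds, "new_iterator"):
        return ds.new_iterator(**kwargs)
    return fallback(ds, **kwargs)

class ToyDataset:
    def new_iterator(self, **kwargs):
        return iter([]), kwargs  # second value is discarded by init_iterators

it, extra = make_iterator(ToyDataset(), fallback=None, batch_size=32)
assert extra == {"batch_size": 32}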

+ 113 - 151
cvfinetune/finetuner/mixins/model.py

@@ -3,207 +3,169 @@ import chainer
 import logging
 
 
 from chainer import functions as F
-from chainer.optimizer_hooks import Lasso
-from chainer.optimizer_hooks import WeightDecay
-from chainer_addons.functions import smoothed_cross_entropy
 from chainer_addons.models import PrepareType
-from chainer_addons.training import optimizer
-from chainer_addons.training import optimizer_hooks
 from chainercv2.models import model_store
 from cvdatasets.dataset.image import Size
-from cvdatasets.utils import pretty_print_dict
 from cvmodelz.models import ModelFactory
 from functools import partial
 from pathlib import Path
 from typing import Tuple
 
 
-def check_param_for_decay(param):
-	return param.name != "alpha"
+from cvfinetune.finetuner.mixins.base import BaseMixin
 
 
-def enable_only_head(chain: chainer.Chain):
-	if hasattr(chain, "enable_only_head") and callable(chain.enable_only_head):
-		chain.enable_only_head()
+class _ModelMixin(BaseMixin):
+    """
+        This mixin is responsible for model selection, model creation,
+        and model weights loading.
+    """
 
 
-	else:
-		chain.disable_update()
-		chain.fc.enable_update()
+    def __init__(self, *args,
+                 model_type: str,
+                 model_kwargs: dict = {},
+                 pooling: str = "g_avg",
 
 
+                 prepare_type: str = "model",
+                 center_crop_on_val: bool = True,
+                 swap_channels: bool = False,
 
 
-class _ModelMixin(abc.ABC):
-	"""
-		This mixin is responsible for optimizer creation, model creation,
-		model wrapping around a classifier and model weights loading.
-	"""
+                 load: str = None,
+                 weights: str = None,
+                 load_path: str = "",
+                 load_strict: bool = False,
+                 load_headless: bool = False,
+                 pretrained_on: str = "imagenet",
 
 
-	def __init__(self, opts, classifier_cls, classifier_kwargs={}, model_kwargs={}, *args, **kwargs):
-		super(_ModelMixin, self).__init__(opts=opts, *args, **kwargs)
-		self.classifier_cls = classifier_cls
-		self.classifier_kwargs = classifier_kwargs
-		self.model_type = opts.model_type
-		self.model_kwargs = model_kwargs
+                 from_scratch: bool = False,
+                 **kwargs):
+        super().__init__(*args, **kwargs)
 
 
+        self.model_type = model_type
+        self.model_kwargs = model_kwargs
 
 
-	@property
-	def model_info(self):
-		return self.data_info.MODELS[self.model_type]
+        self._center_crop_on_val = center_crop_on_val
+        self._swap_channels = swap_channels
 
 
-	def init_model(self, opts):
-		"""creates backbone CNN model. This model is wrapped around the classifier later"""
+        if model_type.startswith("chainercv2"):
+            if prepare_type != "chainercv2":
+                msg = f"Using chainercv2 model, but prepare_type was set to \"{prepare_type}\". "
+                "Setting it to \"chainercv2\"!"
+                warnings.warn(msg)
+            prepare_type = "chainercv2"
 
 
-		self.model = ModelFactory.new(self.model_type,
-			input_size=Size(opts.input_size),
-			**self.model_kwargs
-		)
+        self._prepare_type = prepare_type
+        self._pooling = pooling
 
 
+        self._load = load
+        self._weights = weights
+        self._from_scratch = from_scratch
+        self._load_path = load_path
+        self._load_strict = load_strict
+        self._load_headless = load_headless
+        self._pretrained_on = pretrained_on
 
 
-		if self.model_type.startswith("chainercv2"):
-			opts.prepare_type = "chainercv2"
 
 
-		self.prepare = partial(PrepareType[opts.prepare_type](self.model),
-			swap_channels=opts.swap_channels,
-			keep_ratio=getattr(opts, "center_crop_on_val", False),
-		)
+    def init_model(self):
+        """creates backbone CNN model. This model is wrapped around the classifier later"""
 
 
-		logging.info(
-			f"Created {self.model.__class__.__name__} model "
-			f" with \"{opts.prepare_type}\" prepare function."
-		)
+        self._check_attr("input_size")
 
 
+        self.model = self.new_model()
 
 
-	def init_classifier(self, opts):
+        logging.info(
+            f"Created {type(self.model).__name__} model "
+            f" with \"{self._prepare_type}\" prepare function."
+        )
 
 
-		clf_class, kwargs = self.classifier_cls, self.classifier_kwargs
 
 
-		self.clf = clf_class(
-			model=self.model,
-			loss_func=self._loss_func(opts),
-			**kwargs)
+    def load_weights(self) -> None:
 
 
-		logging.info(
-			f"Wrapped the model around {clf_class.__name__}"
-			f" with kwargs: {pretty_print_dict(kwargs)}"
-		)
+        self._check_attr("clf")
+        self._check_attr("n_classes")
 
 
-	def _loss_func(self, opts):
-		if getattr(opts, "l1_loss", False):
-			return F.hinge
+        finetune, weights = self._get_loader()
 
 
-		label_smoothing = getattr(opts, "label_smoothing", 0)
-		if label_smoothing > 0:
-			assert label_smoothing < 1, "Label smoothing factor must be less than 1!"
+        self.clf.load(weights,
+            n_classes=self.n_classes,
+            finetune=finetune,
 
 
-			return partial(smoothed_cross_entropy, N=self.n_classes, eps=label_smoothing)
+            path=self._load_path,
+            strict=self._load_strict,
+            headless=self._load_headless
+        )
 
 
-		return F.softmax_cross_entropy
+        self.clf.cleargrads()
 
 
-	def init_optimizer(self, opts):
-		"""Creates an optimizer for the classifier """
-		if not hasattr(opts, "optimizer"):
-			self.opt = None
-			return
+        feat_size = self.model.meta.feature_size
 
 
-		opt_kwargs = {}
-		if opts.optimizer == "rmsprop":
-			opt_kwargs["alpha"] = 0.9
+        if hasattr(self.clf, "output_size"):
+            feat_size = self.clf.output_size
 
 
-		if opts.optimizer in ["rmsprop", "adam"]:
-			opt_kwargs["eps"] = 1e-6
+        ### TODO: handle feature size!
 
 
-		self.opt = optimizer(opts.optimizer,
-			self.clf,
-			opts.learning_rate,
-			decay=0, gradient_clipping=False, **opt_kwargs
-		)
+        logging.info(f"Part features size after encoding: {feat_size}")
 
 
-		logging.info(
-			f"Initialized {self.opt.__class__.__name__} optimizer"
-			f" with initial LR {opts.learning_rate} and kwargs: {pretty_print_dict(opt_kwargs)}"
-		)
 
 
-		if opts.decay > 0:
-			reg_kwargs = {}
-			if opts.l1_loss:
-				reg_cls = Lasso
 
 
-			elif opts.pooling == "alpha":
-				reg_cls = optimizer_hooks.SelectiveWeightDecay
-				reg_kwargs["selection"] = check_param_for_decay
+    @property
+    def prepare_type(self):
+        return PrepareType[self._prepare_type]
 
 
-			else:
-				reg_cls = WeightDecay
+    @property
+    def prepare(self):
+        return partial(self.prepare_type(self.model),
+            swap_channels=self._swap_channels,
+            keep_ratio=self._center_crop_on_val)
 
 
-			logging.info(f"Adding {reg_cls.__name__} ({opts.decay:e})")
-			self.opt.add_hook(reg_cls(opts.decay, **reg_kwargs))
+    def new_model(self, **kwargs):
+        return ModelFactory.new(self.model_type,
+            input_size=self.input_size,
+            **self.model_kwargs, **kwargs)
 
 
-		if getattr(opts, "only_head", False):
-			assert not getattr(opts, "recurrent", False), \
-				"Recurrent classifier is not supported with only_head option!"
+    @property
+    def model_info(self):
+        return self.data_info.MODELS[self.model_type]
 
 
-			logging.warning("========= Fine-tuning only classifier layer! =========")
-			enable_only_head(self.clf)
 
 
 
 
-	def _get_loader(self, opts) -> Tuple[bool, str]:
-		if getattr(opts, "from_scratch", False):
-			logging.info("Training a {0.__class__.__name__} model from scratch!".format(self.model))
-			return None, None
+    def _get_loader(self) -> Tuple[bool, str]:
 
 
-		if getattr(opts, "load", None):
-			weights = getattr(opts, "load", None)
-			logging.info(f"Loading already fine-tuned weights from \"{weights}\"")
-			return False, weights
+        if self._from_scratch:
+            logging.info(f"Training a {type(self.model).__name__} model from scratch!")
+            return None, None
 
 
-		elif getattr(opts, "weights", None):
-			weights = getattr(opts, "weights", None)
-			logging.info(f"Loading custom fine-tuned weights from \"{weights}\"")
-			return True, weights
+        if self._load:
+            weights = self._load
+            logging.info(f"Loading already fine-tuned weights from \"{weights}\"")
+            return False, weights
 
 
-		else:
-			weights = self._default_weights(opts)
-			logging.info(f"Loading custom fine-tuned weights from \"{weights}\"")
-			return True, weights
+        elif self._weights:
+            weights = self._weights
+            logging.info(f"Loading custom fine-tuned weights from \"{weights}\"")
+            return True, weights
 
 
-	def _default_weights(self, opts):
-		if self.model_type.startswith("chainercv2"):
-			model_name = self.model_type.split(".")[-1]
-			return model_store.get_model_file(
-				model_name=model_name,
-				local_model_store_dir_path=str(Path.home() / ".chainer" / "models"))
+        else:
+            weights = self._default_weights
+            logging.info(f"Loading custom fine-tuned weights from \"{weights}\"")
+            return True, weights
 
 
-		else:
-			ds_info = self.data_info
-			model_info = self.model_info
+    @property
+    def _default_weights(self):
+        if self.model_type.startswith("chainercv2"):
+            model_name = self.model_type.split(".")[-1]
+            return model_store.get_model_file(
+                model_name=model_name,
+                local_model_store_dir_path=str(Path.home() / ".chainer" / "models"))
 
 
-			base_dir = Path(ds_info.BASE_DIR)
-			weights_dir = base_dir / ds_info.MODEL_DIR / model_info.folder
+        else:
+            ds_info = self.data_info
+            model_info = self.model_info
 
 
-			weights = model_info.weights
-			assert opts.pre_training in weights, \
-				f"Weights for \"{opts.pre_training}\" pre-training were not found!"
+            base_dir = Path(ds_info.BASE_DIR)
+            weights_dir = base_dir / ds_info.MODEL_DIR / model_info.folder
 
 
-			return str(weights_dir / weights[opts.pre_training])
+            weights = model_info.weights
+            assert self._pretrained_on in weights, \
+                f"Weights for \"{self._pretrained_on}\" pre-training were not found!"
 
 
+            return str(weights_dir / weights[self._pretrained_on])
 
 
-	def load_weights(self, opts) -> None:
-
-		finetune, weights = self._get_loader(opts)
-
-		self.clf.load(weights,
-			n_classes=self.n_classes,
-			finetune=finetune,
-
-			path=opts.load_path,
-			strict=opts.load_strict,
-			headless=opts.headless
-		)
-
-		self.clf.cleargrads()
-
-		feat_size = self.model.meta.feature_size
-
-		if hasattr(self.clf, "output_size"):
-			feat_size = self.clf.output_size
-
-		### TODO: handle feature size!
-
-		logging.info(f"Part features size after encoding: {feat_size}")
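
The weight-loading policy now lives in keyword arguments instead of opts. A plain-Python sketch of the decision table in _get_loader (the helper and file names below are illustrative, not part of the commit): from-scratch training loads nothing, the load argument points to already fine-tuned weights, and the weights argument or the dataset's default weight file is loaded in fine-tuning mode.

def pick_loader(from_scratch, load, weights, default_weights):
    if from_scratch:
        return None, None              # nothing to load
    if load:
        return False, load             # already fine-tuned weights, loaded as-is
    if weights:
        return True, weights           # custom weights, loaded in fine-tuning mode
    return True, default_weights       # fall back to the dataset's default weights

assert pick_loader(True, None, None, "imagenet.npz") == (None, None)
assert pick_loader(False, "clf_final.npz", None, "imagenet.npz") == (False, "clf_final.npz")
assert pick_loader(False, None, None, "imagenet.npz") == (True, "imagenet.npz")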

+ 103 - 0
cvfinetune/finetuner/mixins/optimizer.py

@@ -0,0 +1,103 @@
+import abc
+import chainer
+import logging
+
+from chainer.optimizer_hooks import Lasso
+from chainer.optimizer_hooks import WeightDecay
+from chainer_addons.training import optimizer as new_optimizer
+from chainer_addons.training.optimizer_hooks import SelectiveWeightDecay
+from cvdatasets.utils import pretty_print_dict
+
+from cvfinetune.finetuner.mixins.base import BaseMixin
+
+def check_param_for_decay(param):
+    return param.name != "alpha"
+
+def enable_only_head(chain: chainer.Chain):
+    if hasattr(chain, "enable_only_head") and callable(chain.enable_only_head):
+        chain.enable_only_head()
+
+    else:
+        chain.disable_update()
+        chain.fc.enable_update()
+
+class _OptimizerCreator:
+
+    def __init__(self, opt, **kwargs):
+        super().__init__()
+
+        self.opt = opt
+        self.kwargs = kwargs
+
+    def __call__(self, *args, **kwargs):
+        if self.opt is None:
+            return None
+
+        kwargs = dict(self.kwargs, **kwargs)
+        return new_optimizer(self.opt, *args, **kwargs)
+
+class _OptimizerMixin(BaseMixin):
+
+    def __init__(self, *args,
+                 optimizer: str,
+                 learning_rate: float = 1e-3,
+                 weight_decay: float = 5e-4,
+                 eps: float = 1e-2,
+                 only_head: bool = False,
+                 **kwargs):
+
+        super().__init__(*args, **kwargs)
+
+        optimizer_kwargs = dict(decay=0, gradient_clipping=False)
+
+        if optimizer in ["rmsprop", "adam"]:
+            optimizer_kwargs["eps"] = eps
+
+        self._opt_creator = _OptimizerCreator(optimizer, **optimizer_kwargs)
+        self.learning_rate = learning_rate
+        self.weight_decay = weight_decay
+        self._only_head = only_head
+
+
+    def init_optimizer(self):
+        """Creates an optimizer for the classifier """
+
+        self._check_attr("clf")
+        self._check_attr("_pooling")
+        self._check_attr("_l1_loss")
+
+        self.opt = self._opt_creator(self.clf, self.learning_rate)
+
+        if self.opt is None:
+            logging.warning("========= No optimizer was initialized! =========")
+            return
+
+        kwargs = self._opt_creator.kwargs
+        logging.info(
+            f"Initialized {type(self.opt).__name__} optimizer"
+            f" with initial LR {self.learning_rate} and kwargs: {pretty_print_dict(kwargs)}"
+        )
+
+        self.init_regularizer()
+
+        if self._only_head:
+            logging.warning("========= Fine-tuning only classifier layer! =========")
+            enable_only_head(self.clf)
+
+    def init_regularizer(self, **kwargs):
+
+        if self.weight_decay <= 0:
+            return
+
+        if self._l1_loss:
+            cls = Lasso
+
+        elif self._pooling == "alpha":
+            cls = SelectiveWeightDecay
+            kwargs["selection"] = check_param_for_decay
+
+        else:
+            cls = WeightDecay
+
+        logging.info(f"Adding {cls.__name__} ({self.weight_decay:e})")
+        self.opt.add_hook(cls(self.weight_decay, **kwargs))
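
init_regularizer picks the optimizer hook from the loss and pooling configuration. A minimal runnable sketch of that decision, returning the hook class names used in the diff rather than the chainer classes themselves (the helper name is illustrative):

def pick_regularizer(l1_loss: bool, pooling: str) -> str:
    if l1_loss:
        return "Lasso"
    if pooling == "alpha":
        return "SelectiveWeightDecay"  # keeps the "alpha" parameter out of the decay
    return "WeightDecay"

assert pick_regularizer(l1_loss=True, pooling="g_avg") == "Lasso"
assert pick_regularizer(l1_loss=False, pooling="alpha") == "SelectiveWeightDecay"
assert pick_regularizer(l1_loss=False, pooling="g_avg") == "WeightDecay"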

+ 113 - 81
cvfinetune/finetuner/mixins/trainer.py

@@ -1,92 +1,124 @@
 import abc
 import logging
 import pyaml
+import gc
 
 
 from bdb import BdbQuit
 from chainer.serializers import save_npz
+from chainer.training import extension
 from chainer.training import extensions
+from chainer.training import updaters
 from cvdatasets.utils import pretty_print_dict
 from pathlib import Path
 
 
-
-class _TrainerMixin(abc.ABC):
-	"""This mixin is responsible for updater, evaluator and trainer creation.
-	Furthermore, it implements the run method
-	"""
-
-	def __init__(self, opts, updater_cls, updater_kwargs={}, *args, **kwargs):
-		super(_TrainerMixin, self).__init__(*args, **kwargs)
-		self.updater_cls = updater_cls
-		self.updater_kwargs = updater_kwargs
-
-	def init_updater(self):
-		"""Creates an updater from training iterator and the optimizer."""
-
-		if self.opt is None:
-			self.updater = None
-			return
-
-		self.updater = self.updater_cls(
-			iterator=self.train_iter,
-			optimizer=self.opt,
-			device=self.device,
-			**self.updater_kwargs,
-		)
-		logging.info(" ".join([
-			f"Using single GPU: {self.device}.",
-			f"{self.updater_cls.__name__} is initialized",
-			f"with following kwargs: {pretty_print_dict(self.updater_kwargs)}"
-			])
-		)
-
-	def init_evaluator(self, default_name="val"):
-		"""Creates evaluation extension from validation iterator and the classifier."""
-
-		self.evaluator = extensions.Evaluator(
-			iterator=self.val_iter,
-			target=self.clf,
-			device=self.device,
-			progress_bar=True
-		)
-
-		self.evaluator.default_name = default_name
-
-	def _new_trainer(self, trainer_cls, opts, *args, **kwargs):
-		return trainer_cls(
-			opts=opts,
-			updater=self.updater,
-			evaluator=self.evaluator,
-			*args, **kwargs
-		)
-
-	def run(self, trainer_cls, opts, *args, **kwargs):
-
-		trainer = self._new_trainer(trainer_cls, opts, *args, **kwargs)
-
-		self.save_meta_info(opts, folder=Path(trainer.out, "meta"))
-
-		logging.info("Snapshotting is {}abled".format("dis" if opts.no_snapshot else "en"))
-
-		def dump(suffix):
-			if opts.only_eval or opts.no_snapshot:
-				return
-
-			save_npz(Path(trainer.out, f"clf_{suffix}.npz"), self.clf)
-			save_npz(Path(trainer.out, f"model_{suffix}.npz"), self.model)
-
-		try:
-			trainer.run(opts.init_eval or opts.only_eval)
-		except (KeyboardInterrupt, BdbQuit) as e:
-			raise e
-		except Exception as e:
-			dump("exception")
-			raise e
-		else:
-			dump("final")
-
-	def save_meta_info(self, opts, folder: Path):
-		folder.mkdir(parents=True, exist_ok=True)
-
-		with open(folder / "args.yml", "w") as f:
-			pyaml.dump(opts.__dict__, f, sort_keys=True)
+from cvfinetune.finetuner.mixins.base import BaseMixin
+
+@extension.make_extension(default_name="ManualGC", trigger=(1, "iteration"))
+def gc_collect(trainer):
+    gc.collect()
+
+class _TrainerMixin(BaseMixin):
+    """This mixin is responsible for updater, evaluator and trainer creation.
+    Furthermore, it implements the run method
+    """
+
+    def __init__(self, *args,
+                 updater_cls=updaters.StandardUpdater,
+                 updater_kwargs: dict = {},
+                 only_eval: bool = False,
+                 init_eval: bool = False,
+                 no_snapshot: bool = False,
+
+                 manual_gc: bool = True,
+                 **kwargs):
+        super(_TrainerMixin, self).__init__(*args, **kwargs)
+        self.updater_cls = updater_cls
+        self.updater_kwargs = updater_kwargs
+
+        self.only_eval = only_eval
+        self.init_eval = init_eval
+        self.no_snapshot = no_snapshot
+        self.manual_gc = manual_gc
+
+
+    def init_updater(self):
+        """Creates an updater from training iterator and the optimizer."""
+
+        self._check_attr("opt")
+        self._check_attr("device")
+        self._check_attr("train_iter")
+
+        if self.opt is None:
+            self.updater = None
+            return
+
+        self.updater = self.updater_cls(
+            iterator=self.train_iter,
+            optimizer=self.opt,
+            device=self.device,
+            **self.updater_kwargs,
+        )
+        logging.info(" ".join([
+            f"Using single GPU: {self.device}.",
+            f"{self.updater_cls.__name__} is initialized",
+            f"with following kwargs: {pretty_print_dict(self.updater_kwargs)}"
+            ])
+        )
+
+    def init_evaluator(self, default_name="val"):
+        """Creates evaluation extension from validation iterator and the classifier."""
+
+        self._check_attr("device")
+        self._check_attr("val_iter")
+
+        self.evaluator = extensions.Evaluator(
+            iterator=self.val_iter,
+            target=self.clf,
+            device=self.device,
+            progress_bar=True
+        )
+
+        self.evaluator.default_name = default_name
+
+    def _new_trainer(self, trainer_cls, opts, *args, **kwargs):
+        return trainer_cls(
+            opts=opts,
+            updater=self.updater,
+            evaluator=self.evaluator,
+            *args, **kwargs
+        )
+
+    def run(self, trainer_cls, opts, *args, **kwargs):
+
+        trainer = self._new_trainer(trainer_cls, opts, *args, **kwargs)
+
+        if self.manual_gc:
+            trainer.extend(gc_collect)
+
+        self.save_meta_info(opts, folder=Path(trainer.out, "meta"))
+
+        logging.info("Snapshotting is {}abled".format("dis" if self.no_snapshot else "en"))
+
+        def dump(suffix):
+            if self.only_eval or self.no_snapshot:
+                return
+
+            save_npz(Path(trainer.out, f"clf_{suffix}.npz"), self.clf)
+            save_npz(Path(trainer.out, f"model_{suffix}.npz"), self.model)
+
+        try:
+            trainer.run(self.init_eval or self.only_eval)
+        except (KeyboardInterrupt, BdbQuit) as e:
+            raise e
+        except Exception as e:
+            dump("exception")
+            raise e
+        else:
+            dump("final")
+
+    def save_meta_info(self, opts, folder: Path):
+        folder.mkdir(parents=True, exist_ok=True)
+
+        with open(folder / "args.yml", "w") as f:
+            pyaml.dump(opts.__dict__, f, sort_keys=True)
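
run() above wraps trainer.run() so that weights are always dumped unless snapshotting is disabled: clf_final.npz / model_final.npz on a clean finish, *_exception.npz when training crashes; keyboard interrupts and debugger quits re-raise without dumping. A plain-Python sketch of that control flow (function names are illustrative):

def run_with_snapshots(train, dump):
    try:
        train()
    except KeyboardInterrupt:
        raise                  # no snapshot on user interrupt (same for bdb.BdbQuit)
    except Exception:
        dump("exception")      # keep the partially trained weights
        raise
    else:
        dump("final")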
 
 

+ 6 - 6
cvfinetune/finetuner/mpi.py

@@ -8,9 +8,9 @@ from cvfinetune.finetuner.base import DefaultFinetuner
 
 
 class MPIFinetuner(DefaultFinetuner):
 
 
-	def __init__(self, opts, *args, comm, **kwargs):
+	def __init__(self, *args, comm, **kwargs):
 		self.comm = comm
-		super(MPIFinetuner, self).__init__(opts, *args, **kwargs)
+		super(MPIFinetuner, self).__init__(*args, **kwargs)
 
 
 	@property
 	def mpi(self):
@@ -20,16 +20,16 @@ class MPIFinetuner(DefaultFinetuner):
 	def mpi_main_process(self):
 		return not (self.comm is not None and self.comm.rank != 0)
 
 
-	def gpu_config(self, opts):
+	def gpu_config(self, devices):
 
 
 		if not self.mpi:
 			msg = "Using MPIFinetuner without setting a communicator!"
 			warnings.warn(msg)
 			logging.warn(msg)
-			return super(MPIFinetuner, self).gpu_config(opts)
+			return super(MPIFinetuner, self).gpu_config(devices)
 
 
-		if len(opts.gpu) > 1:
-			self.device_id = opts.gpu[self.comm.rank]
+		if len(devices) > 1:
+			self.device_id = devices[self.comm.rank]
 		else:
 			self.device_id += self.comm.intra_rank