
refactored finetuner class mixins

Dimitri Korsch, 3 years ago
commit 1e900ed8e6

+ 13 - 35
cvfinetune/dataset.py

@@ -1,49 +1,27 @@
-import numpy as np
 import abc
-
-from chainer_addons.dataset import AugmentationMixin
-from chainer_addons.dataset import PreprocessMixin
+import numpy as np
 
 from cvdatasets.dataset import AnnotationsReadMixin
-from cvdatasets.dataset import RevealedPartMixin
 from cvdatasets.dataset import IteratorMixin
+from cvdatasets.dataset import TransformMixin
+from cvdatasets.dataset import UniformPartMixin
 
-class _pre_augmentation_mixin(abc.ABC):
-	""" This mixin discards the parts from the ImageWrapper object
-	and shifts the labels
-	"""
-
-	label_shift = 1
-
-	def get_example(self, i):
-		im_obj = super(_pre_augmentation_mixin, self).get_example(i)
-		im, parts, lab = im_obj.as_tuple()
-		return im, lab + self.label_shift
-
-class _base_mixin(abc.ABC):
-	""" This mixin converts images,that are in range
-	[0..1] to the range [-1..1]
-	"""
+class BaseDataset(TransformMixin, UniformPartMixin, AnnotationsReadMixin):
+	"""Commonly used dataset constellation"""
 
-	def get_example(self, i):
-		im, lab = super(_base_mixin, self).get_example(i)
+	def __init__(self, *args, prepare, center_crop_on_val: bool = True, **kwargs):
+		super().__init__(*args, **kwargs)
+		self.prepare = prepare
 
+	def augment(self, im):
 		if isinstance(im, list):
 			im = np.array(im)
 
 		if np.logical_and(0 <= im, im <= 1).all():
 			im = im * 2 -1
 
-		return im, lab
-
+		return im
 
-class BaseDataset(_base_mixin,
-	# augmentation and preprocessing
-	AugmentationMixin, PreprocessMixin,
-	_pre_augmentation_mixin,
-	# random uniform region selection
-	RevealedPartMixin,
-	# reads image
-	AnnotationsReadMixin,
-	IteratorMixin):
-	"""Commonly used dataset constellation"""
+	def transform(self, im_obj):
+		im, parts, lab = im_obj.as_tuple()
+		return self.prepare(im), lab + self.label_shift

+ 27 - 16
cvfinetune/finetuner/base.py

@@ -3,41 +3,52 @@ import logging
 
 from cvfinetune.finetuner import mixins
 
-class DefaultFinetuner(mixins._ModelMixin, mixins._DatasetMixin, mixins._TrainerMixin):
+class DefaultFinetuner(
+	mixins._ModelMixin,
+	mixins._OptimizerMixin,
+	mixins._ClassifierMixin,
+	mixins._DatasetMixin,
+	mixins._IteratorMixin,
+	mixins._TrainerMixin):
 	""" The default Finetuner gathers together the creations of all needed
 	components and call them in the correct order
 
 	"""
 
-	def __init__(self, opts, *args, **kwargs):
-		super(DefaultFinetuner, self).__init__(opts=opts, *args, **kwargs)
+	def __init__(self, *args, gpu = [-1], **kwargs):
+		super().__init__(*args, **kwargs)
 
-		self.gpu_config(opts)
-		self.read_annotations(opts)
+		self.gpu_config(gpu)
+		self.read_annotations()
 
-		self.init_model(opts)
-		self.init_datasets(opts)
-		self.init_iterators(opts)
+		self.init_model()
+		self.init_datasets()
+		self.init_iterators()
 
-		self.init_classifier(opts)
-		self.load_weights(opts)
+		self.init_classifier()
+		self.load_weights()
 
-		self.init_optimizer(opts)
+		self.init_optimizer()
 		self.init_updater()
 		self.init_evaluator()
 
+
+	def _check_attr(self, attr_name, msg=None):
+		msg = msg or f"<{type(self).__name__}> {attr_name} attribute was not initialized!"
+		assert hasattr(self, attr_name), msg
+
 	def init_device(self):
 		self.device = chainer.get_device(self.device_id)
 		self.device.use()
 		return self.device
 
-
-	def gpu_config(self, opts):
-		if -1 in opts.gpu:
+	def gpu_config(self, devices):
+		if -1 in devices:
 			self.device_id = -1
 		else:
-			self.device_id = opts.gpu[0]
+			self.device_id = devices[0]
 
+		device = self.init_device()
 		logging.info(f"Using device {device}")
-		return self.init_device()
+		return device
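
For orientation, a possible construction sketch (not part of this commit): after the refactoring, the finetuner is configured through keyword-only arguments declared by the individual mixins instead of a single parsed opts object. The keyword names below are taken from the mixin signatures in this diff; the model identifier, classifier class, annotation paths, and dataset key are placeholders.

    # Hypothetical usage sketch -- MyClassifier and the paths/identifiers are placeholders.
    from cvfinetune.dataset import BaseDataset
    from cvfinetune.finetuner.base import DefaultFinetuner

    tuner = DefaultFinetuner(
        gpu=[0],                           # device ids; [-1] selects the CPU
        model_type="chainercv2.resnet50",  # placeholder model identifier
        classifier_cls=MyClassifier,       # classifier that wraps the backbone
        dataset_cls=BaseDataset,
        data="path/to/info.yml",           # annotation info file
        dataset="CUB200",                  # dataset key inside the info file
        optimizer="adam",
        learning_rate=1e-3,
        batch_size=32,
        n_jobs=4,
    )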
 

+ 33 - 29
cvfinetune/finetuner/factory.py

@@ -1,10 +1,10 @@
 import logging
 try:
-	import chainermn
+    import chainermn
 except Exception as e: #pragma: no cover
-	_CHAINERMN_AVAILABLE = False #pragma: no cover
+    _CHAINERMN_AVAILABLE = False #pragma: no cover
 else:
-	_CHAINERMN_AVAILABLE = True
+    _CHAINERMN_AVAILABLE = True
 
 from cvfinetune import utils
 from cvfinetune.finetuner.base import DefaultFinetuner
@@ -14,37 +14,41 @@ from cvdatasets.utils import pretty_print_dict
 
 class FinetunerFactory(object):
 
-	@classmethod
-	def new(cls, opts, default=DefaultFinetuner, mpi_tuner=MPIFinetuner):
+    @classmethod
+    def new(cls, *,
+            mpi: bool = False,
+            default=DefaultFinetuner,
+            mpi_tuner=MPIFinetuner,
+            **kwargs):
 
-		if getattr(opts, "mpi", False):
-			assert _CHAINERMN_AVAILABLE, "Distributed training is not possible!"
-			msg1 = "MPI enabled. Creating NCCL communicator!"
-			comm = chainermn.create_communicator("pure_nccl")
-			msg2 = f"Rank: {comm.rank}, IntraRank: {comm.intra_rank}, InterRank: {comm.inter_rank}"
+        if mpi:
+            assert _CHAINERMN_AVAILABLE, "Distributed training is not possible!"
+            msg1 = "MPI enabled. Creating NCCL communicator!"
+            comm = chainermn.create_communicator("pure_nccl")
+            msg2 = f"Rank: {comm.rank}, IntraRank: {comm.intra_rank}, InterRank: {comm.inter_rank}"
 
-			utils.log_messages([msg1, msg2])
-			return cls(mpi_tuner, comm=comm)
-		else:
-			return cls(default)
+            utils.log_messages([msg1, msg2])
+            return cls(mpi_tuner, comm=comm, **kwargs)
+        else:
+            return cls(default, **kwargs)
 
-	def __init__(self, tuner_cls, **kwargs):
-		super(FinetunerFactory, self).__init__()
+    def __init__(self, tuner_cls, **kwargs):
+        super(FinetunerFactory, self).__init__()
 
-		self.tuner_cls = tuner_cls
-		self.kwargs = kwargs
-		logging.info(f"Using {self.tuner_cls.__name__} with arguments: {pretty_print_dict(self.kwargs)}")
+        self.tuner_cls = tuner_cls
+        self.kwargs = kwargs
+        logging.info(f"Using {self.tuner_cls.__name__} with arguments: {pretty_print_dict(self.kwargs)}")
 
-	def __call__(self, **kwargs):
-		_kwargs = dict(self.kwargs, **kwargs)
+    def __call__(self, opts, **kwargs):
+        opt_kwargs = self.tuner_cls.extract_kwargs(opts)
+        _kwargs = dict(self.kwargs, **kwargs, **opt_kwargs)
+        return self.tuner_cls(**_kwargs)
 
-		return self.tuner_cls(**_kwargs)
+    def get(self, key, default=None):
+        return self.kwargs.get(key, default)
 
-	def get(self, key, default=None):
-		return self.kwargs.get(key, default)
+    def __getitem__(self, key):
+        return self.kwargs[key]
 
-	def __getitem__(self, key):
-		return self.kwargs[key]
-
-	def __setitem__(self, key, value):
-		self.kwargs[key] = value
+    def __setitem__(self, key, value):
+        self.kwargs[key] = value
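
A possible call sequence for the new factory API (not part of this commit; opts stands for a parsed argparse namespace, and MyClassifier, MyDataset, MyTrainer are placeholders):

    # Hypothetical usage sketch with placeholder classes.
    from cvfinetune.finetuner.factory import FinetunerFactory

    factory = FinetunerFactory.new(mpi=False,
                                   classifier_cls=MyClassifier,
                                   dataset_cls=MyDataset)

    # __call__ now receives the parsed options and lets the tuner class pull the
    # keyword-only arguments it knows about via extract_kwargs(opts).
    tuner = factory(opts)
    tuner.run(MyTrainer, opts)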

+ 6 - 0
cvfinetune/finetuner/mixins/__init__.py

@@ -1,10 +1,16 @@
 from cvfinetune.finetuner.mixins.dataset import _DatasetMixin
+from cvfinetune.finetuner.mixins.classifier import _ClassifierMixin
 from cvfinetune.finetuner.mixins.model import _ModelMixin
+from cvfinetune.finetuner.mixins.optimizer import _OptimizerMixin
+from cvfinetune.finetuner.mixins.iterator import _IteratorMixin
 from cvfinetune.finetuner.mixins.trainer import _TrainerMixin
 
 
 __all__ = [
 	"_DatasetMixin",
+	"_ClassifierMixin",
 	"_ModelMixin",
+	"_OptimizerMixin",
+	"_IteratorMixin",
 	"_TrainerMixin",
 ]

+ 71 - 0
cvfinetune/finetuner/mixins/base.py

@@ -0,0 +1,71 @@
+import abc
+import inspect
+
+
+class BaseMixin(abc.ABC):
+
+	def _after_init_check(self):
+		pass
+
+	@classmethod
+	def extract_kwargs(cls, opts) -> dict:
+
+		kwargs = {}
+
+		for klass in cls.mro():
+			sig = inspect.signature(klass.__init__)
+			for attr, param in sig.parameters.items():
+				if param.kind is not inspect.Parameter.KEYWORD_ONLY:
+					continue
+
+				if param.name in kwargs:
+					continue
+
+				if hasattr(opts, param.name):
+					value = getattr(opts, param.name)
+					kwargs[param.name] = value
+		return kwargs
+
+
+if __name__ == '__main__':
+
+	from collections import namedtuple
+	class Foo(BaseMixin):
+
+		@classmethod
+		def extract_kwargs(cls, opts) -> dict:
+			return super().extract_kwargs(opts)
+
+		def __init__(self, *args, foo, bar=0, **kwargs):
+			super().__init__(*args, **kwargs)
+			self.foo = foo
+			self.bar = bar
+
+
+	class Bar(BaseMixin):
+		@classmethod
+		def extract_kwargs(cls, opts) -> dict:
+			return super().extract_kwargs(opts)
+
+		def __init__(self, *args, bar2=-1, **kwargs):
+			super().__init__(*args, **kwargs)
+			self.bar2 = bar2
+
+
+	class Final(Bar, Foo):
+
+		def __init__(self, *args, beef=-1, **kwargs):
+			super().__init__(*args, **kwargs)
+			self.beef = beef
+
+		def __repr__(self):
+			return str(self.__dict__)
+
+
+	Opts = namedtuple("Opts", "foo foo2 bar bar2 beef1")
+
+	opts = Opts(1,2,3, -4, "hat")
+	kwargs = Final.extract_kwargs(opts)
+
+	print(opts, Final(**kwargs))
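
A minimal sketch (not part of the commit) of how extract_kwargs behaves with an argparse namespace, which is what FinetunerFactory.__call__ passes in, reusing the Foo/Bar/Final demo classes above:

    # Only keyword-only parameters that also exist as attributes on the
    # namespace are collected; everything else is ignored.
    from argparse import Namespace

    opts = Namespace(foo=1, bar=2, unknown=3)
    print(Final.extract_kwargs(opts))   # -> {'foo': 1, 'bar': 2}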

+ 67 - 0
cvfinetune/finetuner/mixins/classifier.py

@@ -0,0 +1,67 @@
+import abc
+import logging
+
+from chainer import functions as F
+from chainer_addons.functions import smoothed_cross_entropy
+from cvdatasets.utils import pretty_print_dict
+from functools import partial
+
+from cvfinetune.finetuner.mixins.base import BaseMixin
+
+class _ClassifierCreator:
+
+    def __init__(self, cls, **kwargs):
+        super().__init__()
+        self.cls = cls
+        self.kwargs = kwargs
+
+    def __call__(self, *args, **kwargs):
+        kwargs = dict(self.kwargs, **kwargs)
+        return self.cls(*args, **kwargs)
+
+class _ClassifierMixin(BaseMixin):
+    """
+        This mixin wraps the backbone model in a classifier instance.
+    """
+
+    def __init__(self, *args,
+                 classifier_cls,
+                 classifier_kwargs: dict = {},
+                 l1_loss: bool = False,
+                 label_smoothing: float = 0.0,
+                 **kwargs):
+
+        super().__init__(*args, **kwargs)
+        self._clf_creator = _ClassifierCreator(classifier_cls, **classifier_kwargs)
+
+        self._l1_loss = l1_loss
+        self._label_smoothing = label_smoothing
+
+
+    def init_classifier(self):
+        self._check_attr("model")
+        self._check_attr("n_classes")
+
+        self.clf = self._clf_creator(model=self.model,
+                                     loss_func=self.loss_func)
+
+        kwargs = self._clf_creator.kwargs
+        logging.info(
+            f"Wrapped the model around {type(self.clf).__name__}"
+            f" with kwargs: {pretty_print_dict(kwargs)}"
+        )
+
+    @property
+    def loss_func(self):
+        if self._l1_loss:
+            return F.hinge
+
+        if self._label_smoothing > 0:
+            assert self._label_smoothing < 1, "Label smoothing factor must be less than 1!"
+
+            return partial(smoothed_cross_entropy,
+                           N=self.n_classes,
+                           eps=self._label_smoothing)
+
+        return F.softmax_cross_entropy

+ 73 - 68
cvfinetune/finetuner/mixins/dataset.py

@@ -1,96 +1,101 @@
 import abc
 import logging
+import typing as T
 
+from collections import namedtuple
 from cvdatasets import AnnotationType
 from cvdatasets.dataset.image import Size
-from cvdatasets.utils import new_iterator
 
+from cvfinetune.finetuner.mixins.base import BaseMixin
 
-class _DatasetMixin(abc.ABC):
-	"""
-		This mixin is responsible for annotation loading and for
-		dataset and iterator creation.
-	"""
 
-	def __init__(self, opts, dataset_cls, dataset_kwargs_factory, *args, **kwargs):
-		super(_DatasetMixin, self).__init__(opts=opts, *args, **kwargs)
-		self.annot = None
-		self.dataset_type = opts.dataset
-		self.dataset_cls = dataset_cls
-		self.dataset_kwargs_factory = dataset_kwargs_factory
+class _DatasetMixin(BaseMixin):
+    """
+        This mixin is responsible for annotation loading and dataset creation.
+    """
 
-	@property
-	def n_classes(self):
-		return self.ds_info.n_classes + self.dataset_cls.label_shift
+    def __init__(self,
+                 *args,
+                 data: str,
+                 dataset: str,
+                 dataset_cls: T.Type,
+                 dataset_kwargs_factory: T.Optional[T.Callable] = None,
 
-	@property
-	def data_info(self):
-		assert self.annot is not None, "annot attribute was not set!"
-		return self.annot.info
+                 label_shift: int = 0,
+                 input_size: int = 224,
+                 part_input_size:  T.Optional[int] = None,
+                 **kwargs):
 
-	@property
-	def ds_info(self):
-		return self.data_info.DATASETS[self.dataset_type]
+        super().__init__(*args, **kwargs)
+        self.annot = None
+        self.info_file = data
+        self.dataset_name = dataset
+        self.dataset_cls = dataset_cls
+        self.dataset_kwargs_factory = dataset_kwargs_factory
 
-	def new_dataset(self, opts, size, part_size, subset):
-		"""Creates a dataset for a specific subset and certain options"""
-		if self.dataset_kwargs_factory is not None and callable(self.dataset_kwargs_factory):
-			kwargs = self.dataset_kwargs_factory(opts, subset)
-		else:
-			kwargs = dict()
+        self.input_size = Size(input_size)
 
-		kwargs = dict(kwargs,
-			subset=subset,
-			dataset_cls=self.dataset_cls,
-			prepare=self.prepare,
-			size=size,
-			part_size=part_size,
-			center_crop_on_val=getattr(opts, "center_crop_on_val", False),
-		)
+        if part_input_size is None:
+            self.part_input_size = self.input_size
 
+        else:
+            self.part_input_size = Size(part_input_size)
 
-		ds = self.annot.new_dataset(**kwargs)
-		logging.info("Loaded {} images".format(len(ds)))
-		return ds
+        self._label_shift = label_shift
 
 
-	def read_annotations(self, opts):
-		"""Reads annotations and creates annotation instance, which holds important infos about the dataset"""
+    def read_annotations(self):
+        """Reads annotations and creates annotation instance, which holds important infos about the dataset"""
+        opts = namedtuple("Opt", "data dataset")(self.info_file, self.dataset_name)
+        self.annot = AnnotationType.new_annotation(opts, load_strict=False)
+        self.dataset_cls.label_shift = self._label_shift
 
-		self.annot = AnnotationType.new_annotation(opts, load_strict=False)
-		self.dataset_cls.label_shift = opts.label_shift
+    def init_datasets(self):
+        self._check_attr("prepare")
+        self._check_attr("_center_crop_on_val")
 
+        logging.info(" | ".join([
+            f"Image input size: {self.input_size}",
+            f"Parts input size: {self.part_input_size}",
+        ]))
 
-	def init_datasets(self, opts):
+        self.train_data = self.new_dataset("train")
+        self.val_data = self.new_dataset("test")
 
-		size = Size(opts.input_size)
-		part_size = getattr(opts, "parts_input_size", None)
-		part_size = size if part_size is None else Size(part_size)
 
-		logging.info(" ".join([
-			f"Image input size: {size}",
-			f"Image parts input size: {part_size}",
-		]))
+    @property
+    def n_classes(self):
+        return self.ds_info.n_classes + self._label_shift
 
-		self.train_data = self.new_dataset(opts, size, part_size, "train")
-		self.val_data = self.new_dataset(opts, size, part_size, "test")
+    @property
+    def data_info(self):
+        assert self.annot is not None, "annot attribute was not set!"
+        return self.annot.info
 
-	def init_iterators(self, opts):
-		"""Creates training and validation iterators from training and validation datasets"""
+    @property
+    def ds_info(self):
+        return self.data_info.DATASETS[self.dataset_name]
 
-		kwargs = dict(n_jobs=opts.n_jobs, batch_size=opts.batch_size)
+    def new_dataset(self, subset: str):
+        """Creates a dataset for a specific subset and certain options"""
+        if self.dataset_kwargs_factory is not None and callable(self.dataset_kwargs_factory):
+            kwargs = self.dataset_kwargs_factory(subset)
+        else:
+            kwargs = dict()
 
-		if hasattr(self.train_data, "new_iterator"):
-			self.train_iter, _ = self.train_data.new_iterator(**kwargs)
-		else:
-			self.train_iter, _ = new_iterator(self.train_data, **kwargs)
+        kwargs = dict(kwargs,
+            subset=subset,
+            dataset_cls=self.dataset_cls,
+            prepare=self.prepare,
+            size=self.input_size,
+            part_size=self.part_input_size,
+            center_crop_on_val=self._center_crop_on_val,
+        )
+
+
+        ds = self.annot.new_dataset(**kwargs)
+        logging.info(f"Loaded {len(ds)} images")
+        return ds
 
-		if hasattr(self.val_data, "new_iterator"):
-			self.val_iter, _ = self.val_data.new_iterator(**kwargs,
-				repeat=False, shuffle=False
-			)
-		else:
-			self.val_iter, _ = new_iterator(self.val_data,
-				**kwargs, repeat=False, shuffle=False
-			)
 

+ 41 - 0
cvfinetune/finetuner/mixins/iterator.py

@@ -0,0 +1,41 @@
+import abc
+import logging
+import typing as T
+
+from cvdatasets.utils import new_iterator
+
+from cvfinetune.finetuner.mixins.base import BaseMixin
+
+class _IteratorMixin(BaseMixin):
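+    """
+        This mixin creates the training and validation iterators
+        from the training and validation datasets.
+    """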
+
+    def __init__(self,
+                 *args,
+                 batch_size: int = 32,
+                 n_jobs: int = 1,
+                 **kwargs):
+        super().__init__(*args, **kwargs)
+
+        self._batch_size = batch_size
+        self._n_jobs = n_jobs
+
+
+    def new_iterator(self, ds, **kwargs):
+        if hasattr(ds, "new_iterator"):
+            return ds.new_iterator(**kwargs)
+        else:
+            return new_iterator(ds, **kwargs)
+
+    def init_iterators(self):
+        """Creates training and validation iterators from training and validation datasets"""
+
+        self._check_attr("val_data")
+        self._check_attr("train_data")
+
+        kwargs = dict(n_jobs=self._n_jobs, batch_size=self._batch_size)
+
+        self.train_iter, _ = self.new_iterator(self.train_data,
+                                               **kwargs)
+
+        self.val_iter, _ = self.new_iterator(self.val_data,
+                                             repeat=False, shuffle=False,
+                                             **kwargs)

+ 113 - 151
cvfinetune/finetuner/mixins/model.py

@@ -3,207 +3,169 @@ import chainer
 import logging
 
 from chainer import functions as F
-from chainer.optimizer_hooks import Lasso
-from chainer.optimizer_hooks import WeightDecay
-from chainer_addons.functions import smoothed_cross_entropy
 from chainer_addons.models import PrepareType
-from chainer_addons.training import optimizer
-from chainer_addons.training import optimizer_hooks
 from chainercv2.models import model_store
 from cvdatasets.dataset.image import Size
-from cvdatasets.utils import pretty_print_dict
 from cvmodelz.models import ModelFactory
 from functools import partial
 from pathlib import Path
 from typing import Tuple
 
-def check_param_for_decay(param):
-	return param.name != "alpha"
+from cvfinetune.finetuner.mixins.base import BaseMixin
 
-def enable_only_head(chain: chainer.Chain):
-	if hasattr(chain, "enable_only_head") and callable(chain.enable_only_head):
-		chain.enable_only_head()
+class _ModelMixin(BaseMixin):
+    """
+        This mixin is responsible for model selection, model creation,
+        and model weights loading.
+    """
 
-	else:
-		chain.disable_update()
-		chain.fc.enable_update()
+    def __init__(self, *args,
+                 model_type: str,
+                 model_kwargs: dict = {},
+                 pooling: str = "g_avg",
 
+                 prepare_type: str = "model",
+                 center_crop_on_val: bool = True,
+                 swap_channels: bool = False,
 
-class _ModelMixin(abc.ABC):
-	"""
-		This mixin is responsible for optimizer creation, model creation,
-		model wrapping around a classifier and model weights loading.
-	"""
+                 load: str = None,
+                 weights: str = None,
+                 load_path: str = "",
+                 load_strict: bool = False,
+                 load_headless: bool = False,
+                 pretrained_on: str = "imagenet",
 
-	def __init__(self, opts, classifier_cls, classifier_kwargs={}, model_kwargs={}, *args, **kwargs):
-		super(_ModelMixin, self).__init__(opts=opts, *args, **kwargs)
-		self.classifier_cls = classifier_cls
-		self.classifier_kwargs = classifier_kwargs
-		self.model_type = opts.model_type
-		self.model_kwargs = model_kwargs
+                 from_scratch: bool = False,
+                 **kwargs):
+        super().__init__(*args, **kwargs)
 
+        self.model_type = model_type
+        self.model_kwargs = model_kwargs
 
-	@property
-	def model_info(self):
-		return self.data_info.MODELS[self.model_type]
+        self._center_crop_on_val = center_crop_on_val
+        self._swap_channels = swap_channels
 
-	def init_model(self, opts):
-		"""creates backbone CNN model. This model is wrapped around the classifier later"""
+        if model_type.startswith("chainercv2"):
+            if prepare_type != "chainercv2":
+                msg = (f"Using chainercv2 model, but prepare_type was set to \"{prepare_type}\". "
+                       "Setting it to \"chainercv2\"!")
+                warnings.warn(msg)
+            prepare_type = "chainercv2"
 
-		self.model = ModelFactory.new(self.model_type,
-			input_size=Size(opts.input_size),
-			**self.model_kwargs
-		)
+        self._prepare_type = prepare_type
+        self._pooling = pooling
 
+        self._load = load
+        self._weights = weights
+        self._from_scratch = from_scratch
+        self._load_path = load_path
+        self._load_strict = load_strict
+        self._load_headless = load_headless
+        self._pretrained_on = pretrained_on
 
-		if self.model_type.startswith("chainercv2"):
-			opts.prepare_type = "chainercv2"
 
-		self.prepare = partial(PrepareType[opts.prepare_type](self.model),
-			swap_channels=opts.swap_channels,
-			keep_ratio=getattr(opts, "center_crop_on_val", False),
-		)
+    def init_model(self):
+        """creates backbone CNN model. This model is wrapped around the classifier later"""
 
-		logging.info(
-			f"Created {self.model.__class__.__name__} model "
-			f" with \"{opts.prepare_type}\" prepare function."
-		)
+        self._check_attr("input_size")
 
+        self.model = self.new_model()
 
-	def init_classifier(self, opts):
+        logging.info(
+            f"Created {type(self.model).__name__} model "
+            f" with \"{self._prepare_type}\" prepare function."
+        )
 
-		clf_class, kwargs = self.classifier_cls, self.classifier_kwargs
 
-		self.clf = clf_class(
-			model=self.model,
-			loss_func=self._loss_func(opts),
-			**kwargs)
+    def load_weights(self) -> None:
 
-		logging.info(
-			f"Wrapped the model around {clf_class.__name__}"
-			f" with kwargs: {pretty_print_dict(kwargs)}"
-		)
+        self._check_attr("clf")
+        self._check_attr("n_classes")
 
-	def _loss_func(self, opts):
-		if getattr(opts, "l1_loss", False):
-			return F.hinge
+        finetune, weights = self._get_loader()
 
-		label_smoothing = getattr(opts, "label_smoothing", 0)
-		if label_smoothing > 0:
-			assert label_smoothing < 1, "Label smoothing factor must be less than 1!"
+        self.clf.load(weights,
+            n_classes=self.n_classes,
+            finetune=finetune,
 
-			return partial(smoothed_cross_entropy, N=self.n_classes, eps=label_smoothing)
+            path=self._load_path,
+            strict=self._load_strict,
+            headless=self._load_headless
+        )
 
-		return F.softmax_cross_entropy
+        self.clf.cleargrads()
 
-	def init_optimizer(self, opts):
-		"""Creates an optimizer for the classifier """
-		if not hasattr(opts, "optimizer"):
-			self.opt = None
-			return
+        feat_size = self.model.meta.feature_size
 
-		opt_kwargs = {}
-		if opts.optimizer == "rmsprop":
-			opt_kwargs["alpha"] = 0.9
+        if hasattr(self.clf, "output_size"):
+            feat_size = self.clf.output_size
 
-		if opts.optimizer in ["rmsprop", "adam"]:
-			opt_kwargs["eps"] = 1e-6
+        ### TODO: handle feature size!
 
-		self.opt = optimizer(opts.optimizer,
-			self.clf,
-			opts.learning_rate,
-			decay=0, gradient_clipping=False, **opt_kwargs
-		)
+        logging.info(f"Part features size after encoding: {feat_size}")
 
-		logging.info(
-			f"Initialized {self.opt.__class__.__name__} optimizer"
-			f" with initial LR {opts.learning_rate} and kwargs: {pretty_print_dict(opt_kwargs)}"
-		)
 
-		if opts.decay > 0:
-			reg_kwargs = {}
-			if opts.l1_loss:
-				reg_cls = Lasso
 
-			elif opts.pooling == "alpha":
-				reg_cls = optimizer_hooks.SelectiveWeightDecay
-				reg_kwargs["selection"] = check_param_for_decay
+    @property
+    def prepare_type(self):
+        return PrepareType[self._prepare_type]
 
-			else:
-				reg_cls = WeightDecay
+    @property
+    def prepare(self):
+        return partial(self.prepare_type(self.model),
+            swap_channels=self._swap_channels,
+            keep_ratio=self._center_crop_on_val)
 
-			logging.info(f"Adding {reg_cls.__name__} ({opts.decay:e})")
-			self.opt.add_hook(reg_cls(opts.decay, **reg_kwargs))
+    def new_model(self, **kwargs):
+        return ModelFactory.new(self.model_type,
+            input_size=self.input_size,
+            **self.model_kwargs, **kwargs)
 
-		if getattr(opts, "only_head", False):
-			assert not getattr(opts, "recurrent", False), \
-				"Recurrent classifier is not supported with only_head option!"
+    @property
+    def model_info(self):
+        return self.data_info.MODELS[self.model_type]
 
-			logging.warning("========= Fine-tuning only classifier layer! =========")
-			enable_only_head(self.clf)
 
 
-	def _get_loader(self, opts) -> Tuple[bool, str]:
-		if getattr(opts, "from_scratch", False):
-			logging.info("Training a {0.__class__.__name__} model from scratch!".format(self.model))
-			return None, None
+    def _get_loader(self) -> Tuple[bool, str]:
 
-		if getattr(opts, "load", None):
-			weights = getattr(opts, "load", None)
-			logging.info(f"Loading already fine-tuned weights from \"{weights}\"")
-			return False, weights
+        if self._from_scratch:
+            logging.info(f"Training a {type(self.model).__name__} model from scratch!")
+            return None, None
 
-		elif getattr(opts, "weights", None):
-			weights = getattr(opts, "weights", None)
-			logging.info(f"Loading custom fine-tuned weights from \"{weights}\"")
-			return True, weights
+        if self._load:
+            weights = self._load
+            logging.info(f"Loading already fine-tuned weights from \"{weights}\"")
+            return False, weights
 
-		else:
-			weights = self._default_weights(opts)
-			logging.info(f"Loading custom fine-tuned weights from \"{weights}\"")
-			return True, weights
+        elif self._weights:
+            weights = self._weights
+            logging.info(f"Loading custom fine-tuned weights from \"{weights}\"")
+            return True, weights
 
-	def _default_weights(self, opts):
-		if self.model_type.startswith("chainercv2"):
-			model_name = self.model_type.split(".")[-1]
-			return model_store.get_model_file(
-				model_name=model_name,
-				local_model_store_dir_path=str(Path.home() / ".chainer" / "models"))
+        else:
+            weights = self._default_weights
+            logging.info(f"Loading custom fine-tuned weights from \"{weights}\"")
+            return True, weights
 
-		else:
-			ds_info = self.data_info
-			model_info = self.model_info
+    @property
+    def _default_weights(self):
+        if self.model_type.startswith("chainercv2"):
+            model_name = self.model_type.split(".")[-1]
+            return model_store.get_model_file(
+                model_name=model_name,
+                local_model_store_dir_path=str(Path.home() / ".chainer" / "models"))
 
-			base_dir = Path(ds_info.BASE_DIR)
-			weights_dir = base_dir / ds_info.MODEL_DIR / model_info.folder
+        else:
+            ds_info = self.data_info
+            model_info = self.model_info
 
-			weights = model_info.weights
-			assert opts.pre_training in weights, \
-				f"Weights for \"{opts.pre_training}\" pre-training were not found!"
+            base_dir = Path(ds_info.BASE_DIR)
+            weights_dir = base_dir / ds_info.MODEL_DIR / model_info.folder
 
-			return str(weights_dir / weights[opts.pre_training])
+            weights = model_info.weights
+            assert self._pretrained_on in weights, \
+                f"Weights for \"{self._pretrained_on}\" pre-training were not found!"
 
+            return str(weights_dir / weights[self._pretrained_on])
 
-	def load_weights(self, opts) -> None:
-
-		finetune, weights = self._get_loader(opts)
-
-		self.clf.load(weights,
-			n_classes=self.n_classes,
-			finetune=finetune,
-
-			path=opts.load_path,
-			strict=opts.load_strict,
-			headless=opts.headless
-		)
-
-		self.clf.cleargrads()
-
-		feat_size = self.model.meta.feature_size
-
-		if hasattr(self.clf, "output_size"):
-			feat_size = self.clf.output_size
-
-		### TODO: handle feature size!
-
-		logging.info(f"Part features size after encoding: {feat_size}")

+ 103 - 0
cvfinetune/finetuner/mixins/optimizer.py

@@ -0,0 +1,103 @@
+import abc
+import chainer
+import logging
+
+from chainer.optimizer_hooks import Lasso
+from chainer.optimizer_hooks import WeightDecay
+from chainer_addons.training import optimizer as new_optimizer
+from chainer_addons.training.optimizer_hooks import SelectiveWeightDecay
+from cvdatasets.utils import pretty_print_dict
+
+from cvfinetune.finetuner.mixins.base import BaseMixin
+
+def check_param_for_decay(param):
+    return param.name != "alpha"
+
+def enable_only_head(chain: chainer.Chain):
+    if hasattr(chain, "enable_only_head") and callable(chain.enable_only_head):
+        chain.enable_only_head()
+
+    else:
+        chain.disable_update()
+        chain.fc.enable_update()
+
+class _OptimizerCreator:
+
+    def __init__(self, opt, **kwargs):
+        super().__init__()
+
+        self.opt = opt
+        self.kwargs = kwargs
+
+    def __call__(self, *args, **kwargs):
+        if self.opt is None:
+            return None
+
+        kwargs = dict(self.kwargs, **kwargs)
+        return new_optimizer(self.opt, *args, **kwargs)
+
+class _OptimizerMixin(BaseMixin):
+
+    def __init__(self, *args,
+                 optimizer: str,
+                 learning_rate: float = 1e-3,
+                 weight_decay: float = 5e-4,
+                 eps: float = 1e-2,
+                 only_head: bool = False,
+                 **kwargs):
+
+        super().__init__(*args, **kwargs)
+
+        optimizer_kwargs = dict(decay=0, gradient_clipping=False)
+
+        if optimizer in ["rmsprop", "adam"]:
+            optimizer_kwargs["eps"] = eps
+
+        self._opt_creator = _OptimizerCreator(optimizer, **optimizer_kwargs)
+        self.learning_rate = learning_rate
+        self.weight_decay = weight_decay
+        self._only_head = only_head
+
+
+    def init_optimizer(self):
+        """Creates an optimizer for the classifier """
+
+        self._check_attr("clf")
+        self._check_attr("_pooling")
+        self._check_attr("_l1_loss")
+
+        self.opt = self._opt_creator(self.clf, self.learning_rate)
+
+        if self.opt is None:
+            logging.warning("========= No optimizer was initialized! =========")
+            return
+
+        kwargs = self._opt_creator.kwargs
+        logging.info(
+            f"Initialized {type(self.opt).__name__} optimizer"
+            f" with initial LR {self.learning_rate} and kwargs: {pretty_print_dict(kwargs)}"
+        )
+
+        self.init_regularizer()
+
+        if self._only_head:
+            logging.warning("========= Fine-tuning only classifier layer! =========")
+            enable_only_head(self.clf)
+
+    def init_regularizer(self, **kwargs):
+
+        if self.weight_decay <= 0:
+            return
+
+        if self._l1_loss:
+            cls = Lasso
+
+        elif self._pooling == "alpha":
+            cls = SelectiveWeightDecay
+            kwargs["selection"] = check_param_for_decay
+
+        else:
+            cls = WeightDecay
+
+        logging.info(f"Adding {cls.__name__} ({self.weight_decay:e})")
+        self.opt.add_hook(cls(self.weight_decay, **kwargs))

+ 113 - 81
cvfinetune/finetuner/mixins/trainer.py

@@ -1,92 +1,124 @@
 import abc
 import logging
 import pyaml
+import gc
 
 from bdb import BdbQuit
 from chainer.serializers import save_npz
+from chainer.training import extension
 from chainer.training import extensions
+from chainer.training import updaters
 from cvdatasets.utils import pretty_print_dict
 from pathlib import Path
 
-
-class _TrainerMixin(abc.ABC):
-	"""This mixin is responsible for updater, evaluator and trainer creation.
-	Furthermore, it implements the run method
-	"""
-
-	def __init__(self, opts, updater_cls, updater_kwargs={}, *args, **kwargs):
-		super(_TrainerMixin, self).__init__(*args, **kwargs)
-		self.updater_cls = updater_cls
-		self.updater_kwargs = updater_kwargs
-
-	def init_updater(self):
-		"""Creates an updater from training iterator and the optimizer."""
-
-		if self.opt is None:
-			self.updater = None
-			return
-
-		self.updater = self.updater_cls(
-			iterator=self.train_iter,
-			optimizer=self.opt,
-			device=self.device,
-			**self.updater_kwargs,
-		)
-		logging.info(" ".join([
-			f"Using single GPU: {self.device}.",
-			f"{self.updater_cls.__name__} is initialized",
-			f"with following kwargs: {pretty_print_dict(self.updater_kwargs)}"
-			])
-		)
-
-	def init_evaluator(self, default_name="val"):
-		"""Creates evaluation extension from validation iterator and the classifier."""
-
-		self.evaluator = extensions.Evaluator(
-			iterator=self.val_iter,
-			target=self.clf,
-			device=self.device,
-			progress_bar=True
-		)
-
-		self.evaluator.default_name = default_name
-
-	def _new_trainer(self, trainer_cls, opts, *args, **kwargs):
-		return trainer_cls(
-			opts=opts,
-			updater=self.updater,
-			evaluator=self.evaluator,
-			*args, **kwargs
-		)
-
-	def run(self, trainer_cls, opts, *args, **kwargs):
-
-		trainer = self._new_trainer(trainer_cls, opts, *args, **kwargs)
-
-		self.save_meta_info(opts, folder=Path(trainer.out, "meta"))
-
-		logging.info("Snapshotting is {}abled".format("dis" if opts.no_snapshot else "en"))
-
-		def dump(suffix):
-			if opts.only_eval or opts.no_snapshot:
-				return
-
-			save_npz(Path(trainer.out, f"clf_{suffix}.npz"), self.clf)
-			save_npz(Path(trainer.out, f"model_{suffix}.npz"), self.model)
-
-		try:
-			trainer.run(opts.init_eval or opts.only_eval)
-		except (KeyboardInterrupt, BdbQuit) as e:
-			raise e
-		except Exception as e:
-			dump("exception")
-			raise e
-		else:
-			dump("final")
-
-	def save_meta_info(self, opts, folder: Path):
-		folder.mkdir(parents=True, exist_ok=True)
-
-		with open(folder / "args.yml", "w") as f:
-			pyaml.dump(opts.__dict__, f, sort_keys=True)
+from cvfinetune.finetuner.mixins.base import BaseMixin
+
+@extension.make_extension(default_name="ManualGC", trigger=(1, "iteration"))
+def gc_collect(trainer):
+    gc.collect()
+
+class _TrainerMixin(BaseMixin):
+    """This mixin is responsible for updater, evaluator and trainer creation.
+    Furthermore, it implements the run method
+    """
+
+    def __init__(self, *args,
+                 updater_cls=updaters.StandardUpdater,
+                 updater_kwargs: dict = {},
+                 only_eval: bool = False,
+                 init_eval: bool = False,
+                 no_snapshot: bool = False,
+
+                 manual_gc: bool = True,
+                 **kwargs):
+        super(_TrainerMixin, self).__init__(*args, **kwargs)
+        self.updater_cls = updater_cls
+        self.updater_kwargs = updater_kwargs
+
+        self.only_eval = only_eval
+        self.init_eval = init_eval
+        self.no_snapshot = no_snapshot
+        self.manual_gc = manual_gc
+
+
+    def init_updater(self):
+        """Creates an updater from training iterator and the optimizer."""
+
+        self._check_attr("opt")
+        self._check_attr("device")
+        self._check_attr("train_iter")
+
+        if self.opt is None:
+            self.updater = None
+            return
+
+        self.updater = self.updater_cls(
+            iterator=self.train_iter,
+            optimizer=self.opt,
+            device=self.device,
+            **self.updater_kwargs,
+        )
+        logging.info(" ".join([
+            f"Using single GPU: {self.device}.",
+            f"{self.updater_cls.__name__} is initialized",
+            f"with following kwargs: {pretty_print_dict(self.updater_kwargs)}"
+            ])
+        )
+
+    def init_evaluator(self, default_name="val"):
+        """Creates evaluation extension from validation iterator and the classifier."""
+
+        self._check_attr("device")
+        self._check_attr("val_iter")
+
+        self.evaluator = extensions.Evaluator(
+            iterator=self.val_iter,
+            target=self.clf,
+            device=self.device,
+            progress_bar=True
+        )
+
+        self.evaluator.default_name = default_name
+
+    def _new_trainer(self, trainer_cls, opts, *args, **kwargs):
+        return trainer_cls(
+            opts=opts,
+            updater=self.updater,
+            evaluator=self.evaluator,
+            *args, **kwargs
+        )
+
+    def run(self, trainer_cls, opts, *args, **kwargs):
+
+        trainer = self._new_trainer(trainer_cls, opts, *args, **kwargs)
+
+        if self.manual_gc:
+            trainer.extend(gc_collect)
+
+        self.save_meta_info(opts, folder=Path(trainer.out, "meta"))
+
+        logging.info("Snapshotting is {}abled".format("dis" if self.no_snapshot else "en"))
+
+        def dump(suffix):
+            if self.only_eval or self.no_snapshot:
+                return
+
+            save_npz(Path(trainer.out, f"clf_{suffix}.npz"), self.clf)
+            save_npz(Path(trainer.out, f"model_{suffix}.npz"), self.model)
+
+        try:
+            trainer.run(self.init_eval or self.only_eval)
+        except (KeyboardInterrupt, BdbQuit) as e:
+            raise e
+        except Exception as e:
+            dump("exception")
+            raise e
+        else:
+            dump("final")
+
+    def save_meta_info(self, opts, folder: Path):
+        folder.mkdir(parents=True, exist_ok=True)
+
+        with open(folder / "args.yml", "w") as f:
+            pyaml.dump(opts.__dict__, f, sort_keys=True)
 

+ 6 - 6
cvfinetune/finetuner/mpi.py

@@ -8,9 +8,9 @@ from cvfinetune.finetuner.base import DefaultFinetuner
 
 class MPIFinetuner(DefaultFinetuner):
 
-	def __init__(self, opts, *args, comm, **kwargs):
+	def __init__(self, *args, comm, **kwargs):
 		self.comm = comm
-		super(MPIFinetuner, self).__init__(opts, *args, **kwargs)
+		super(MPIFinetuner, self).__init__(*args, **kwargs)
 
 	@property
 	def mpi(self):
@@ -20,16 +20,16 @@ class MPIFinetuner(DefaultFinetuner):
 	def mpi_main_process(self):
 		return not (self.comm is not None and self.comm.rank != 0)
 
-	def gpu_config(self, opts):
+	def gpu_config(self, devices):
 
 		if not self.mpi:
 			msg = "Using MPIFinetuner without setting a communicator!"
 			warnings.warn(msg)
 			logging.warn(msg)
-			return super(MPIFinetuner, self).gpu_config(opts)
+			return super(MPIFinetuner, self).gpu_config(devices)
 
-		if len(opts.gpu) > 1:
-			self.device_id = opts.gpu[self.comm.rank]
+		if len(devices) > 1:
+			self.device_id = devices[self.comm.rank]
 		else:
 			self.device_id += self.comm.intra_rank