
refactored finetuner module a bit

Dimitri Korsch 4 years ago
parent commit 184697051a

+ 2 - 51
cvfinetune/finetuner/__init__.py

@@ -1,58 +1,9 @@
-
-import logging
-try:
-	import chainermn
-except Exception as e: #pragma: no cover
-	_CHAINERMN_AVAILABLE = False #pragma: no cover
-else:
-	_CHAINERMN_AVAILABLE = True
-
-from cvfinetune import utils
 from cvfinetune.finetuner.base import DefaultFinetuner
+from cvfinetune.finetuner.factory import FinetunerFactory
 from cvfinetune.finetuner.mpi import MPIFinetuner
 
-from cvdatasets.utils import pretty_print_dict
-
-class FinetunerFactory(object):
-
-	@classmethod
-	def new(cls, opts, default=DefaultFinetuner, mpi_tuner=MPIFinetuner):
-
-		if getattr(opts, "mpi", False):
-			assert _CHAINERMN_AVAILABLE, "Distributed training is not possible!"
-			msg1 = "MPI enabled. Creating NCCL communicator!"
-			comm = chainermn.create_communicator("pure_nccl")
-			msg2 = f"Rank: {comm.rank}, IntraRank: {comm.intra_rank}, InterRank: {comm.inter_rank}"
-
-			utils.log_messages([msg1, msg2])
-			return cls(mpi_tuner, comm=comm)
-		else:
-			return cls(default)
-
-	def __init__(self, tuner_cls, **kwargs):
-		super(FinetunerFactory, self).__init__()
-
-		self.tuner_cls = tuner_cls
-		self.kwargs = kwargs
-		logging.info(f"Using {self.tuner_cls.__name__} with arguments: {pretty_print_dict(self.kwargs)}")
-
-	def __call__(self, **kwargs):
-		_kwargs = dict(self.kwargs)
-		_kwargs.update(kwargs)
-
-		return self.tuner_cls(**_kwargs)
-
-	def get(self, key, default=None):
-		return self.kwargs.get(key, default)
-
-	def __getitem__(self, key):
-		return self.kwargs[key]
-
-	def __setitem__(self, key, value):
-		self.kwargs[key] = value
-
 __all__ = [
-	"get_finetuner",
+	"FinetunerFactory",
 	"DefaultFinetuner",
 	"MPIFinetuner",
 ]

+ 8 - 335
cvfinetune/finetuner/base.py

@@ -6,32 +6,31 @@ import abc
 import logging
 import pyaml
 
+from bdb import BdbQuit
 from chainer.backends import cuda
 from chainer.optimizer_hooks import Lasso
 from chainer.optimizer_hooks import WeightDecay
 from chainer.serializers import save_npz
 from chainer.training import extensions
-
-from chainercv2.model_provider import get_model
-from chainercv2.models import model_store
-
 from chainer_addons.functions import smoothed_cross_entropy
 from chainer_addons.models import Classifier
 from chainer_addons.models import ModelType
 from chainer_addons.models import PrepareType
 from chainer_addons.training import optimizer
 from chainer_addons.training import optimizer_hooks
-
+from chainercv2.model_provider import get_model
+from chainercv2.models import model_store
 from cvdatasets import AnnotationType
 from cvdatasets.dataset.image import Size
 from cvdatasets.utils import new_iterator
 from cvdatasets.utils import pretty_print_dict
-
-from bdb import BdbQuit
 from functools import partial
 from pathlib import Path
 
 
+
+from cvfinetune.finetuner import mixins
+
 def check_param_for_decay(param):
 	return param.name != "alpha"
 
@@ -43,341 +42,15 @@ def enable_only_head(chain: chainer.Chain):
 		chain.disable_update()
 		chain.fc.enable_update()
 
-class _ModelMixin(abc.ABC):
-	"""This mixin is responsible for optimizer creation, model creation,
-	model wrapping around a classifier and model weights loading.
-	"""
-
-	def __init__(self, opts, classifier_cls, classifier_kwargs={}, model_kwargs={}, *args, **kwargs):
-		super(_ModelMixin, self).__init__(opts=opts, *args, **kwargs)
-		self.classifier_cls = classifier_cls
-		self.classifier_kwargs = classifier_kwargs
-		self.model_type = opts.model_type
-		self.model_kwargs = model_kwargs
-
-
-	@property
-	def model_info(self):
-		return self.data_info.MODELS[self.model_type]
-
-	def wrap_model(self, opts):
-
-		clf_class, kwargs = self.classifier_cls, self.classifier_kwargs
-
-		self.clf = clf_class(
-			model=self.model,
-			loss_func=self._loss_func(opts),
-			**kwargs)
-
-		logging.info(" ".join([
-			f"Wrapped the model around {clf_class.__name__}",
-			f"with kwargs: {pretty_print_dict(kwargs)}",
-		]))
-
-	def _loss_func(self, opts):
-		if getattr(opts, "l1_loss", False):
-			return F.hinge
-
-		elif getattr(opts, "label_smoothing", 0) >= 0:
-			assert getattr(opts, "label_smoothing", 0) < 1, \
-				"Label smoothing factor must be less than 1!"
-			return partial(smoothed_cross_entropy,
-				N=self.n_classes,
-				eps=getattr(opts, "label_smoothing", 0))
-		else:
-			return F.softmax_cross_entropy
-
-	def init_optimizer(self, opts):
-		"""Creates an optimizer for the classifier """
-		if not hasattr(opts, "optimizer"):
-			self.opt = None
-			return
-
-		opt_kwargs = {}
-		if opts.optimizer == "rmsprop":
-			opt_kwargs["alpha"] = 0.9
-
-		self.opt = optimizer(opts.optimizer,
-			self.clf,
-			opts.learning_rate,
-			decay=0, gradient_clipping=False, **opt_kwargs
-		)
-
-		if opts.decay > 0:
-			reg_kwargs = {}
-			if opts.l1_loss:
-				reg_cls = Lasso
-
-			elif opts.pooling == "alpha":
-				reg_cls = optimizer_hooks.SelectiveWeightDecay
-				reg_kwargs["selection"] = check_param_for_decay
-
-			else:
-				reg_cls = WeightDecay
-
-			logging.info(f"Adding {reg_cls.__name__} ({opts.decay:e})")
-			self.opt.add_hook(reg_cls(opts.decay, **reg_kwargs))
-
-		if getattr(opts, "only_head", False):
-			assert not getattr(opts, "recurrent", False), \
-				"Recurrent classifier is not supported with only_head option!"
-
-			logging.warning("========= Fine-tuning only classifier layer! =========")
-			enable_only_head(self.clf)
-
-	def init_model(self, opts):
-		"""creates backbone CNN model. This model is wrapped around the classifier later"""
-
-		if self.model_type.startswith("cv2_"):
-			model_type = args.model_type.split("cv2_")[-1]
-		else:
-			model_type = self.model_info.class_key
-
-			# model = get_model(model_type, pretrained=False)
-
-		self.model = ModelType.new(
-			model_type=model_type,
-			input_size=Size(opts.input_size),
-			**self.model_kwargs,
-		)
-
-	def load_model_weights(self, args):
-		if getattr(args, "from_scratch", False):
-			logging.info("Training a {0.__class__.__name__} model from scratch!".format(self.model))
-			loader = self.model.reinitialize_clf
-			self.weights = None
-		else:
-			if args.load:
-				self.weights = args.load
-				msg = "Loading already fine-tuned weights from \"{}\""
-				loader_func = self.model.load_for_inference
-			else:
-				if args.weights:
-					msg = "Loading custom pre-trained weights \"{}\""
-					self.weights = args.weights
-
-				else:
-					msg = "Loading default pre-trained weights \"{}\""
-					self.weights = str(Path(
-						self.data_info.BASE_DIR,
-						self.data_info.MODEL_DIR,
-						self.model_info.folder,
-						self.model_info.weights
-					))
-
-				loader_func = self.model.load_for_finetune
-
-			logging.info(msg.format(self.weights))
-			kwargs = dict(
-				weights=self.weights,
-				strict=args.load_strict,
-				path=args.load_path,
-				headless=args.headless,
-			)
-			loader = partial(loader_func, **kwargs)
-
-		feat_size = self.model.meta.feature_size
-
-		if hasattr(self.clf, "output_size"):
-			feat_size = self.clf.output_size
-
-		if hasattr(self.clf, "loader"):
-			loader = self.clf.loader(loader)
-
-		logging.info(f"Part features size after encoding: {feat_size}")
-		loader(n_classes=self.n_classes, feat_size=feat_size)
-		self.clf.cleargrads()
-
-class _DatasetMixin(abc.ABC):
-	"""
-		This mixin is responsible for annotation loading and for
-		dataset and iterator creation.
-	"""
-
-	def __init__(self, opts, dataset_cls, dataset_kwargs_factory, *args, **kwargs):
-		super(_DatasetMixin, self).__init__(opts=opts, *args, **kwargs)
-		self.annot = None
-		self.dataset_type = opts.dataset
-		self.dataset_cls = dataset_cls
-		self.dataset_kwargs_factory = dataset_kwargs_factory
-
-	@property
-	def n_classes(self):
-		return self.ds_info.n_classes + self.dataset_cls.label_shift
-
-	@property
-	def data_info(self):
-		assert self.annot is not None, "annot attribute was not set!"
-		return self.annot.info
-
-	@property
-	def ds_info(self):
-		return self.data_info.DATASETS[self.dataset_type]
-
-	def new_dataset(self, opts, size, part_size, subset):
-		"""Creates a dataset for a specific subset and certain options"""
-		if self.dataset_kwargs_factory is not None and callable(self.dataset_kwargs_factory):
-			kwargs = self.dataset_kwargs_factory(opts, subset)
-		else:
-			kwargs = dict()
-
-		kwargs = dict(kwargs,
-			subset=subset,
-			dataset_cls=self.dataset_cls,
-			prepare=self.prepare,
-			size=size,
-			part_size=part_size,
-			center_crop_on_val=getattr(opts, "center_crop_on_val", False),
-		)
-
-
-		ds = self.annot.new_dataset(**kwargs)
-		logging.info("Loaded {} images".format(len(ds)))
-		return ds
-
-
-	def init_annotations(self, opts):
-		"""Reads annotations and creates annotation instance, which holds important infos about the dataset"""
-
-		self.annot = AnnotationType.new_annotation(opts, load_strict=False)
-		self.dataset_cls.label_shift = opts.label_shift
-
-
-	def init_datasets(self, opts):
-
-		size = Size(opts.input_size)
-		part_size = getattr(opts, "parts_input_size", None)
-		part_size = size if part_size is None else Size(part_size)
-
-		self.prepare = partial(PrepareType[opts.prepare_type](self.model),
-			swap_channels=opts.swap_channels,
-			keep_ratio=getattr(opts, "center_crop_on_val", False),
-		)
-
-		logging.info(" ".join([
-			f"Created {self.model.__class__.__name__} model",
-			f"with \"{opts.prepare_type}\" prepare function."
-		]))
-
-		logging.info(" ".join([
-			f"Image input size: {size}",
-			f"Image parts input size: {part_size}",
-		]))
-
-		self.train_data = self.new_dataset(opts, size, part_size, "train")
-		self.val_data = self.new_dataset(opts, size, part_size, "test")
-
-	def init_iterators(self, opts):
-		"""Creates training and validation iterators from training and validation datasets"""
-
-		kwargs = dict(n_jobs=opts.n_jobs, batch_size=opts.batch_size)
-
-		if hasattr(self.train_data, "new_iterator"):
-			self.train_iter, _ = self.train_data.new_iterator(**kwargs)
-		else:
-			self.train_iter, _ = new_iterator(self.train_data, **kwargs)
-
-		if hasattr(self.val_data, "new_iterator"):
-			self.val_iter, _ = self.val_data.new_iterator(**kwargs,
-				repeat=False, shuffle=False
-			)
-		else:
-			self.val_iter, _ = new_iterator(self.val_data,
-				**kwargs, repeat=False, shuffle=False
-			)
-
-
-class _TrainerMixin(abc.ABC):
-	"""This mixin is responsible for updater, evaluator and trainer creation.
-	Furthermore, it implements the run method
-	"""
-
-	def __init__(self, updater_cls, updater_kwargs={}, *args, **kwargs):
-		super(_TrainerMixin, self).__init__(*args, **kwargs)
-		self.updater_cls = updater_cls
-		self.updater_kwargs = updater_kwargs
-
-	def init_updater(self):
-		"""Creates an updater from training iterator and the optimizer."""
-
-		if self.opt is None:
-			self.updater = None
-			return
-
-		self.updater = self.updater_cls(
-			iterator=self.train_iter,
-			optimizer=self.opt,
-			device=self.device,
-			**self.updater_kwargs,
-		)
-		logging.info(" ".join([
-			f"Using single GPU: {self.device}.",
-			f"{self.updater_cls.__name__} is initialized",
-			f"with following kwargs: {pretty_print_dict(self.updater_kwargs)}"
-			])
-		)
-
-	def init_evaluator(self, default_name="val"):
-		"""Creates evaluation extension from validation iterator and the classifier."""
-
-		self.evaluator = extensions.Evaluator(
-			iterator=self.val_iter,
-			target=self.clf,
-			device=self.device,
-			progress_bar=True
-		)
-
-		self.evaluator.default_name = default_name
-
-	def _new_trainer(self, trainer_cls, opts, *args, **kwargs):
-		return trainer_cls(
-			opts=opts,
-			updater=self.updater,
-			evaluator=self.evaluator,
-			*args, **kwargs
-		)
-
-	def run(self, trainer_cls, opts, *args, **kwargs):
-
-		trainer = self._new_trainer(trainer_cls, opts, *args, **kwargs)
-
-		self.save_meta_info(opts, folder=Path(trainer.out, "meta"))
-
-		logging.info("Snapshotting is {}abled".format("dis" if opts.no_snapshot else "en"))
-
-		def dump(suffix):
-			if opts.only_eval or opts.no_snapshot:
-				return
-
-			save_npz(Path(trainer.out, f"clf_{suffix}.npz"), self.clf)
-			save_npz(Path(trainer.out, f"model_{suffix}.npz"), self.model)
-
-		try:
-			trainer.run(opts.init_eval or opts.only_eval)
-		except (KeyboardInterrupt, BdbQuit) as e:
-			raise e
-		except Exception as e:
-			dump("exception")
-			raise e
-		else:
-			dump("final")
-
-	def save_meta_info(self, opts, folder: Path):
-		folder.mkdir(parents=True, exist_ok=True)
-
-		with open(folder / "args.yml", "w") as f:
-			pyaml.dump(opts.__dict__, f, sort_keys=True)
-
-
 
-class DefaultFinetuner(_ModelMixin, _DatasetMixin, _TrainerMixin):
+class DefaultFinetuner(mixins._ModelMixin, mixins._DatasetMixin, mixins._TrainerMixin):
 	""" The default Finetuner gathers together the creations of all needed
 	components and call them in the correct order
 
 	"""
 
 	def __init__(self, opts, *args, **kwargs):
-		super(DefaultFinetuner, self).__init__(*args, **kwargs)
+		super(DefaultFinetuner, self).__init__(opts=opts, *args, **kwargs)
 
 		self.gpu_config(opts)
 		cuda.get_device_from_id(self.device).use()

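The only functional change left in base.py is that DefaultFinetuner.__init__ now forwards opts to super(), because the mixins it inherits from consume opts cooperatively along the MRO. A standalone sketch of that cooperative chain (simplified stand-in classes, not cvfinetune code):

class _A:  # stands in for a mixin that needs opts (e.g. _ModelMixin)
    def __init__(self, opts, *args, **kwargs):
        super().__init__(opts=opts, *args, **kwargs)
        self.model_type = opts["model_type"]

class _B:  # stands in for the last mixin in the MRO, which absorbs opts
    def __init__(self, opts, *args, **kwargs):
        super().__init__(*args, **kwargs)

class Tuner(_A, _B):
    def __init__(self, opts, *args, **kwargs):
        # mirrors the fixed DefaultFinetuner.__init__: without opts=opts here,
        # _A.__init__ would raise a TypeError for the missing argument
        super().__init__(opts=opts, *args, **kwargs)

Tuner(opts={"model_type": "resnet"})
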
+ 50 - 0
cvfinetune/finetuner/factory.py

@@ -0,0 +1,50 @@
+import logging
+try:
+	import chainermn
+except Exception as e: #pragma: no cover
+	_CHAINERMN_AVAILABLE = False #pragma: no cover
+else:
+	_CHAINERMN_AVAILABLE = True
+
+from cvfinetune import utils
+from cvfinetune.finetuner.base import DefaultFinetuner
+from cvfinetune.finetuner.mpi import MPIFinetuner
+
+from cvdatasets.utils import pretty_print_dict
+
+class FinetunerFactory(object):
+
+	@classmethod
+	def new(cls, opts, default=DefaultFinetuner, mpi_tuner=MPIFinetuner):
+
+		if getattr(opts, "mpi", False):
+			assert _CHAINERMN_AVAILABLE, "Distributed training is not possible!"
+			msg1 = "MPI enabled. Creating NCCL communicator!"
+			comm = chainermn.create_communicator("pure_nccl")
+			msg2 = f"Rank: {comm.rank}, IntraRank: {comm.intra_rank}, InterRank: {comm.inter_rank}"
+
+			utils.log_messages([msg1, msg2])
+			return cls(mpi_tuner, comm=comm)
+		else:
+			return cls(default)
+
+	def __init__(self, tuner_cls, **kwargs):
+		super(FinetunerFactory, self).__init__()
+
+		self.tuner_cls = tuner_cls
+		self.kwargs = kwargs
+		logging.info(f"Using {self.tuner_cls.__name__} with arguments: {pretty_print_dict(self.kwargs)}")
+
+	def __call__(self, **kwargs):
+		_kwargs = dict(self.kwargs, **kwargs)
+
+		return self.tuner_cls(**_kwargs)
+
+	def get(self, key, default=None):
+		return self.kwargs.get(key, default)
+
+	def __getitem__(self, key):
+		return self.kwargs[key]
+
+	def __setitem__(self, key, value):
+		self.kwargs[key] = value

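Usage of the extracted FinetunerFactory is unchanged by the move. A minimal sketch, assuming an argparse-style opts namespace; the keyword arguments a concrete tuner needs (classifier_cls, dataset_cls, updater_cls, ...) are omitted and depend on the project setup:

from argparse import Namespace

from cvfinetune.finetuner import FinetunerFactory

opts = Namespace(mpi=False)

# mpi=False selects DefaultFinetuner; with mpi=True (and chainermn installed)
# new() creates a "pure_nccl" communicator and selects MPIFinetuner instead.
factory = FinetunerFactory.new(opts)

# the stored default kwargs behave like a small dict
factory["comm"] = None
assert factory.get("comm") is None

# calling the factory merges the stored defaults with per-call kwargs and
# instantiates the wrapped tuner class, e.g.:
# tuner = factory(opts=opts, classifier_cls=..., dataset_cls=..., updater_cls=...)
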
+ 10 - 0
cvfinetune/finetuner/mixins/__init__.py

@@ -0,0 +1,10 @@
+from cvfinetune.finetuner.mixins.dataset import _DatasetMixin
+from cvfinetune.finetuner.mixins.model import _ModelMixin
+from cvfinetune.finetuner.mixins.trainer import _TrainerMixin
+
+
+__all__ = [
+	"_DatasetMixin",
+	"_ModelMixin",
+	"_TrainerMixin",
+]

+ 108 - 0
cvfinetune/finetuner/mixins/dataset.py

@@ -0,0 +1,108 @@
+import abc
+import logging
+
+from chainer_addons.models import PrepareType
+from cvdatasets import AnnotationType
+from cvdatasets.dataset.image import Size
+from cvdatasets.utils import new_iterator
+from functools import partial
+
+
+class _DatasetMixin(abc.ABC):
+	"""
+		This mixin is responsible for annotation loading and for
+		dataset and iterator creation.
+	"""
+
+	def __init__(self, opts, dataset_cls, dataset_kwargs_factory, *args, **kwargs):
+		super(_DatasetMixin, self).__init__(opts=opts, *args, **kwargs)
+		self.annot = None
+		self.dataset_type = opts.dataset
+		self.dataset_cls = dataset_cls
+		self.dataset_kwargs_factory = dataset_kwargs_factory
+
+	@property
+	def n_classes(self):
+		return self.ds_info.n_classes + self.dataset_cls.label_shift
+
+	@property
+	def data_info(self):
+		assert self.annot is not None, "annot attribute was not set!"
+		return self.annot.info
+
+	@property
+	def ds_info(self):
+		return self.data_info.DATASETS[self.dataset_type]
+
+	def new_dataset(self, opts, size, part_size, subset):
+		"""Creates a dataset for a specific subset and certain options"""
+		if self.dataset_kwargs_factory is not None and callable(self.dataset_kwargs_factory):
+			kwargs = self.dataset_kwargs_factory(opts, subset)
+		else:
+			kwargs = dict()
+
+		kwargs = dict(kwargs,
+			subset=subset,
+			dataset_cls=self.dataset_cls,
+			prepare=self.prepare,
+			size=size,
+			part_size=part_size,
+			center_crop_on_val=getattr(opts, "center_crop_on_val", False),
+		)
+
+
+		ds = self.annot.new_dataset(**kwargs)
+		logging.info("Loaded {} images".format(len(ds)))
+		return ds
+
+
+	def init_annotations(self, opts):
+		"""Reads annotations and creates annotation instance, which holds important infos about the dataset"""
+
+		self.annot = AnnotationType.new_annotation(opts, load_strict=False)
+		self.dataset_cls.label_shift = opts.label_shift
+
+
+	def init_datasets(self, opts):
+
+		size = Size(opts.input_size)
+		part_size = getattr(opts, "parts_input_size", None)
+		part_size = size if part_size is None else Size(part_size)
+
+		self.prepare = partial(PrepareType[opts.prepare_type](self.model),
+			swap_channels=opts.swap_channels,
+			keep_ratio=getattr(opts, "center_crop_on_val", False),
+		)
+
+		logging.info(" ".join([
+			f"Created {self.model.__class__.__name__} model",
+			f"with \"{opts.prepare_type}\" prepare function."
+		]))
+
+		logging.info(" ".join([
+			f"Image input size: {size}",
+			f"Image parts input size: {part_size}",
+		]))
+
+		self.train_data = self.new_dataset(opts, size, part_size, "train")
+		self.val_data = self.new_dataset(opts, size, part_size, "test")
+
+	def init_iterators(self, opts):
+		"""Creates training and validation iterators from training and validation datasets"""
+
+		kwargs = dict(n_jobs=opts.n_jobs, batch_size=opts.batch_size)
+
+		if hasattr(self.train_data, "new_iterator"):
+			self.train_iter, _ = self.train_data.new_iterator(**kwargs)
+		else:
+			self.train_iter, _ = new_iterator(self.train_data, **kwargs)
+
+		if hasattr(self.val_data, "new_iterator"):
+			self.val_iter, _ = self.val_data.new_iterator(**kwargs,
+				repeat=False, shuffle=False
+			)
+		else:
+			self.val_iter, _ = new_iterator(self.val_data,
+				**kwargs, repeat=False, shuffle=False
+			)
+

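The dataset_kwargs_factory argument accepted by _DatasetMixin is a callable taking (opts, subset) and returning extra keyword arguments for annot.new_dataset(). A short sketch with a hypothetical augment flag (the key name is an assumption, not part of cvdatasets):

def my_dataset_kwargs(opts, subset: str) -> dict:
    # enable augmentation only for the training split; the returned dict is
    # merged with subset, dataset_cls, prepare, size, part_size, etc.
    return dict(augment=(subset == "train"))

# passed as dataset_kwargs_factory=my_dataset_kwargs when building the finetuner
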
+ 159 - 0
cvfinetune/finetuner/mixins/model.py

@@ -0,0 +1,159 @@
+import abc
+import chainer.functions as F
+import logging
+
+from chainer.optimizer_hooks import Lasso
+from chainer.optimizer_hooks import WeightDecay
+from chainer_addons.functions import smoothed_cross_entropy
+from chainer_addons.models import ModelType
+from chainer_addons.training import optimizer
+from chainer_addons.training import optimizer_hooks
+from cvdatasets.dataset.image import Size
+from cvdatasets.utils import pretty_print_dict
+from functools import partial
+from pathlib import Path
+
+
+class _ModelMixin(abc.ABC):
+	"""
+		This mixin is responsible for optimizer creation, model creation,
+		model wrapping around a classifier and model weights loading.
+	"""
+
+	def __init__(self, opts, classifier_cls, classifier_kwargs={}, model_kwargs={}, *args, **kwargs):
+		super(_ModelMixin, self).__init__(opts=opts, *args, **kwargs)
+		self.classifier_cls = classifier_cls
+		self.classifier_kwargs = classifier_kwargs
+		self.model_type = opts.model_type
+		self.model_kwargs = model_kwargs
+
+
+	@property
+	def model_info(self):
+		return self.data_info.MODELS[self.model_type]
+
+	def wrap_model(self, opts):
+
+		clf_class, kwargs = self.classifier_cls, self.classifier_kwargs
+
+		self.clf = clf_class(
+			model=self.model,
+			loss_func=self._loss_func(opts),
+			**kwargs)
+
+		logging.info(" ".join([
+			f"Wrapped the model around {clf_class.__name__}",
+			f"with kwargs: {pretty_print_dict(kwargs)}",
+		]))
+
+	def _loss_func(self, opts):
+		if getattr(opts, "l1_loss", False):
+			return F.hinge
+
+		elif getattr(opts, "label_smoothing", 0) >= 0:
+			assert getattr(opts, "label_smoothing", 0) < 1, \
+				"Label smoothing factor must be less than 1!"
+			return partial(smoothed_cross_entropy,
+				N=self.n_classes,
+				eps=getattr(opts, "label_smoothing", 0))
+		else:
+			return F.softmax_cross_entropy
+
+	def init_optimizer(self, opts):
+		"""Creates an optimizer for the classifier """
+		if not hasattr(opts, "optimizer"):
+			self.opt = None
+			return
+
+		opt_kwargs = {}
+		if opts.optimizer == "rmsprop":
+			opt_kwargs["alpha"] = 0.9
+
+		self.opt = optimizer(opts.optimizer,
+			self.clf,
+			opts.learning_rate,
+			decay=0, gradient_clipping=False, **opt_kwargs
+		)
+
+		if opts.decay > 0:
+			reg_kwargs = {}
+			if opts.l1_loss:
+				reg_cls = Lasso
+
+			elif opts.pooling == "alpha":
+				reg_cls = optimizer_hooks.SelectiveWeightDecay
+				reg_kwargs["selection"] = check_param_for_decay
+
+			else:
+				reg_cls = WeightDecay
+
+			logging.info(f"Adding {reg_cls.__name__} ({opts.decay:e})")
+			self.opt.add_hook(reg_cls(opts.decay, **reg_kwargs))
+
+		if getattr(opts, "only_head", False):
+			assert not getattr(opts, "recurrent", False), \
+				"Recurrent classifier is not supported with only_head option!"
+
+			logging.warning("========= Fine-tuning only classifier layer! =========")
+			enable_only_head(self.clf)
+
+	def init_model(self, opts):
+		"""creates backbone CNN model. This model is wrapped around the classifier later"""
+
+		if self.model_type.startswith("cv2_"):
+			model_type = self.model_type.split("cv2_")[-1]
+		else:
+			model_type = self.model_info.class_key
+
+		self.model = ModelType.new(
+			model_type=model_type,
+			input_size=Size(opts.input_size),
+			**self.model_kwargs,
+		)
+
+	def load_model_weights(self, args):
+		if getattr(args, "from_scratch", False):
+			logging.info("Training a {0.__class__.__name__} model from scratch!".format(self.model))
+			loader = self.model.reinitialize_clf
+			self.weights = None
+		else:
+			if args.load:
+				self.weights = args.load
+				msg = "Loading already fine-tuned weights from \"{}\""
+				loader_func = self.model.load_for_inference
+			else:
+				if args.weights:
+					msg = "Loading custom pre-trained weights \"{}\""
+					self.weights = args.weights
+
+				else:
+					msg = "Loading default pre-trained weights \"{}\""
+					self.weights = str(Path(
+						self.data_info.BASE_DIR,
+						self.data_info.MODEL_DIR,
+						self.model_info.folder,
+						self.model_info.weights
+					))
+
+				loader_func = self.model.load_for_finetune
+
+			logging.info(msg.format(self.weights))
+			kwargs = dict(
+				weights=self.weights,
+				strict=args.load_strict,
+				path=args.load_path,
+				headless=args.headless,
+			)
+			loader = partial(loader_func, **kwargs)
+
+		feat_size = self.model.meta.feature_size
+
+		if hasattr(self.clf, "output_size"):
+			feat_size = self.clf.output_size
+
+		if hasattr(self.clf, "loader"):
+			loader = self.clf.loader(loader)
+
+		logging.info(f"Part features size after encoding: {feat_size}")
+		loader(n_classes=self.n_classes, feat_size=feat_size)
+		self.clf.cleargrads()

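For reference, these are the option attributes _ModelMixin reads; the attribute names come from the code above, while the values below are placeholders rather than project defaults:

from argparse import Namespace

model_opts = Namespace(
    model_type="resnet",      # key into data_info.MODELS, or "cv2_<name>" for chainercv2 backbones
    input_size=224,
    optimizer="adam", learning_rate=1e-3, decay=5e-4,
    l1_loss=False, label_smoothing=0.1, pooling="avg",
    only_head=False, recurrent=False,
    from_scratch=False, load=None, weights=None,
    load_strict=False, load_path="", headless=False,
)
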
+ 92 - 0
cvfinetune/finetuner/mixins/trainer.py

@@ -0,0 +1,92 @@
+import abc
+import logging
+import pyaml
+
+from bdb import BdbQuit
+from chainer.serializers import save_npz
+from chainer.training import extensions
+from cvdatasets.utils import pretty_print_dict
+from pathlib import Path
+
+
+class _TrainerMixin(abc.ABC):
+	"""This mixin is responsible for updater, evaluator and trainer creation.
+	Furthermore, it implements the run method
+	"""
+
+	def __init__(self, opts, updater_cls, updater_kwargs={}, *args, **kwargs):
+		super(_TrainerMixin, self).__init__(*args, **kwargs)
+		self.updater_cls = updater_cls
+		self.updater_kwargs = updater_kwargs
+
+	def init_updater(self):
+		"""Creates an updater from training iterator and the optimizer."""
+
+		if self.opt is None:
+			self.updater = None
+			return
+
+		self.updater = self.updater_cls(
+			iterator=self.train_iter,
+			optimizer=self.opt,
+			device=self.device,
+			**self.updater_kwargs,
+		)
+		logging.info(" ".join([
+			f"Using single GPU: {self.device}.",
+			f"{self.updater_cls.__name__} is initialized",
+			f"with following kwargs: {pretty_print_dict(self.updater_kwargs)}"
+			])
+		)
+
+	def init_evaluator(self, default_name="val"):
+		"""Creates evaluation extension from validation iterator and the classifier."""
+
+		self.evaluator = extensions.Evaluator(
+			iterator=self.val_iter,
+			target=self.clf,
+			device=self.device,
+			progress_bar=True
+		)
+
+		self.evaluator.default_name = default_name
+
+	def _new_trainer(self, trainer_cls, opts, *args, **kwargs):
+		return trainer_cls(
+			opts=opts,
+			updater=self.updater,
+			evaluator=self.evaluator,
+			*args, **kwargs
+		)
+
+	def run(self, trainer_cls, opts, *args, **kwargs):
+
+		trainer = self._new_trainer(trainer_cls, opts, *args, **kwargs)
+
+		self.save_meta_info(opts, folder=Path(trainer.out, "meta"))
+
+		logging.info("Snapshotting is {}abled".format("dis" if opts.no_snapshot else "en"))
+
+		def dump(suffix):
+			if opts.only_eval or opts.no_snapshot:
+				return
+
+			save_npz(Path(trainer.out, f"clf_{suffix}.npz"), self.clf)
+			save_npz(Path(trainer.out, f"model_{suffix}.npz"), self.model)
+
+		try:
+			trainer.run(opts.init_eval or opts.only_eval)
+		except (KeyboardInterrupt, BdbQuit) as e:
+			raise e
+		except Exception as e:
+			dump("exception")
+			raise e
+		else:
+			dump("final")
+
+	def save_meta_info(self, opts, folder: Path):
+		folder.mkdir(parents=True, exist_ok=True)
+
+		with open(folder / "args.yml", "w") as f:
+			pyaml.dump(opts.__dict__, f, sort_keys=True)
+
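
Finally, the flags _TrainerMixin.run() consults when deciding whether to evaluate, train, and snapshot; again an argparse-style sketch where the attribute names come from the code above and the values are placeholders:

from argparse import Namespace

run_opts = Namespace(
    init_eval=False,    # evaluate once before training starts
    only_eval=False,    # evaluate only; also suppresses the npz snapshots
    no_snapshot=False,  # if True, dump() writes no clf_*/model_*.npz files
)
# tuner.run(trainer_cls, run_opts) builds the trainer via _new_trainer(),
# saves args.yml under <trainer.out>/meta, runs training, and writes
# clf_final.npz / model_final.npz (or clf_exception.npz if training fails).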