Browse source code

refactored finetuner module a bit

Dimitri Korsch, 4 years ago
parent
commit
184697051a

+ 2 - 51
cvfinetune/finetuner/__init__.py

@@ -1,58 +1,9 @@
-
-import logging
-try:
-	import chainermn
-except Exception as e: #pragma: no cover
-	_CHAINERMN_AVAILABLE = False #pragma: no cover
-else:
-	_CHAINERMN_AVAILABLE = True
-
-from cvfinetune import utils
 from cvfinetune.finetuner.base import DefaultFinetuner
+from cvfinetune.finetuner.factory import FinetunerFactory
 from cvfinetune.finetuner.mpi import MPIFinetuner
 
-from cvdatasets.utils import pretty_print_dict
-
-class FinetunerFactory(object):
-
-	@classmethod
-	def new(cls, opts, default=DefaultFinetuner, mpi_tuner=MPIFinetuner):
-
-		if getattr(opts, "mpi", False):
-			assert _CHAINERMN_AVAILABLE, "Distributed training is not possible!"
-			msg1 = "MPI enabled. Creating NCCL communicator!"
-			comm = chainermn.create_communicator("pure_nccl")
-			msg2 = f"Rank: {comm.rank}, IntraRank: {comm.intra_rank}, InterRank: {comm.inter_rank}"
-
-			utils.log_messages([msg1, msg2])
-			return cls(mpi_tuner, comm=comm)
-		else:
-			return cls(default)
-
-	def __init__(self, tuner_cls, **kwargs):
-		super(FinetunerFactory, self).__init__()
-
-		self.tuner_cls = tuner_cls
-		self.kwargs = kwargs
-		logging.info(f"Using {self.tuner_cls.__name__} with arguments: {pretty_print_dict(self.kwargs)}")
-
-	def __call__(self, **kwargs):
-		_kwargs = dict(self.kwargs)
-		_kwargs.update(kwargs)
-
-		return self.tuner_cls(**_kwargs)
-
-	def get(self, key, default=None):
-		return self.kwargs.get(key, default)
-
-	def __getitem__(self, key):
-		return self.kwargs[key]
-
-	def __setitem__(self, key, value):
-		self.kwargs[key] = value
-
 __all__ = [
-	"get_finetuner",
+	"FinetunerFactory",
 	"DefaultFinetuner",
 	"MPIFinetuner",
 ]
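
The factory class moves to cvfinetune/finetuner/factory.py (shown below) but is re-exported here, so the package-level import path is unchanged. A minimal usage sketch, assuming the package is importable; opts is a placeholder namespace and the tuner keyword arguments are elided because they come from the mixins introduced further down:

from types import SimpleNamespace
from cvfinetune.finetuner import FinetunerFactory, DefaultFinetuner, MPIFinetuner

opts = SimpleNamespace(mpi=False)   # opts.mpi=True selects MPIFinetuner and requires chainermn

# FinetunerFactory.new() picks the tuner class based on opts.mpi
factory = FinetunerFactory.new(opts)
assert factory.tuner_cls is DefaultFinetuner

# calling the factory instantiates the selected tuner class; the required
# keyword arguments are defined by the mixins shown further down
# tuner = factory(opts=opts, ...)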

+ 8 - 335
cvfinetune/finetuner/base.py

@@ -6,32 +6,31 @@ import abc
 import logging
 import pyaml
 
+from bdb import BdbQuit
 from chainer.backends import cuda
 from chainer.optimizer_hooks import Lasso
 from chainer.optimizer_hooks import WeightDecay
 from chainer.serializers import save_npz
 from chainer.training import extensions
-
-from chainercv2.model_provider import get_model
-from chainercv2.models import model_store
-
 from chainer_addons.functions import smoothed_cross_entropy
 from chainer_addons.models import Classifier
 from chainer_addons.models import ModelType
 from chainer_addons.models import PrepareType
 from chainer_addons.training import optimizer
 from chainer_addons.training import optimizer_hooks
-
+from chainercv2.model_provider import get_model
+from chainercv2.models import model_store
 from cvdatasets import AnnotationType
 from cvdatasets.dataset.image import Size
 from cvdatasets.utils import new_iterator
 from cvdatasets.utils import pretty_print_dict
-
-from bdb import BdbQuit
 from functools import partial
 from pathlib import Path
 
 
+
+from cvfinetune.finetuner import mixins
+
 def check_param_for_decay(param):
 	return param.name != "alpha"
 
@@ -43,341 +42,15 @@ def enable_only_head(chain: chainer.Chain):
 		chain.disable_update()
 		chain.fc.enable_update()
 
-class _ModelMixin(abc.ABC):
-	"""This mixin is responsible for optimizer creation, model creation,
-	model wrapping around a classifier and model weights loading.
-	"""
-
-	def __init__(self, opts, classifier_cls, classifier_kwargs={}, model_kwargs={}, *args, **kwargs):
-		super(_ModelMixin, self).__init__(opts=opts, *args, **kwargs)
-		self.classifier_cls = classifier_cls
-		self.classifier_kwargs = classifier_kwargs
-		self.model_type = opts.model_type
-		self.model_kwargs = model_kwargs
-
-
-	@property
-	def model_info(self):
-		return self.data_info.MODELS[self.model_type]
-
-	def wrap_model(self, opts):
-
-		clf_class, kwargs = self.classifier_cls, self.classifier_kwargs
-
-		self.clf = clf_class(
-			model=self.model,
-			loss_func=self._loss_func(opts),
-			**kwargs)
-
-		logging.info(" ".join([
-			f"Wrapped the model around {clf_class.__name__}",
-			f"with kwargs: {pretty_print_dict(kwargs)}",
-		]))
-
-	def _loss_func(self, opts):
-		if getattr(opts, "l1_loss", False):
-			return F.hinge
-
-		elif getattr(opts, "label_smoothing", 0) >= 0:
-			assert getattr(opts, "label_smoothing", 0) < 1, \
-				"Label smoothing factor must be less than 1!"
-			return partial(smoothed_cross_entropy,
-				N=self.n_classes,
-				eps=getattr(opts, "label_smoothing", 0))
-		else:
-			return F.softmax_cross_entropy
-
-	def init_optimizer(self, opts):
-		"""Creates an optimizer for the classifier """
-		if not hasattr(opts, "optimizer"):
-			self.opt = None
-			return
-
-		opt_kwargs = {}
-		if opts.optimizer == "rmsprop":
-			opt_kwargs["alpha"] = 0.9
-
-		self.opt = optimizer(opts.optimizer,
-			self.clf,
-			opts.learning_rate,
-			decay=0, gradient_clipping=False, **opt_kwargs
-		)
-
-		if opts.decay > 0:
-			reg_kwargs = {}
-			if opts.l1_loss:
-				reg_cls = Lasso
-
-			elif opts.pooling == "alpha":
-				reg_cls = optimizer_hooks.SelectiveWeightDecay
-				reg_kwargs["selection"] = check_param_for_decay
-
-			else:
-				reg_cls = WeightDecay
-
-			logging.info(f"Adding {reg_cls.__name__} ({opts.decay:e})")
-			self.opt.add_hook(reg_cls(opts.decay, **reg_kwargs))
-
-		if getattr(opts, "only_head", False):
-			assert not getattr(opts, "recurrent", False), \
-				"Recurrent classifier is not supported with only_head option!"
-
-			logging.warning("========= Fine-tuning only classifier layer! =========")
-			enable_only_head(self.clf)
-
-	def init_model(self, opts):
-		"""creates backbone CNN model. This model is wrapped around the classifier later"""
-
-		if self.model_type.startswith("cv2_"):
-			model_type = args.model_type.split("cv2_")[-1]
-		else:
-			model_type = self.model_info.class_key
-
-			# model = get_model(model_type, pretrained=False)
-
-		self.model = ModelType.new(
-			model_type=model_type,
-			input_size=Size(opts.input_size),
-			**self.model_kwargs,
-		)
-
-	def load_model_weights(self, args):
-		if getattr(args, "from_scratch", False):
-			logging.info("Training a {0.__class__.__name__} model from scratch!".format(self.model))
-			loader = self.model.reinitialize_clf
-			self.weights = None
-		else:
-			if args.load:
-				self.weights = args.load
-				msg = "Loading already fine-tuned weights from \"{}\""
-				loader_func = self.model.load_for_inference
-			else:
-				if args.weights:
-					msg = "Loading custom pre-trained weights \"{}\""
-					self.weights = args.weights
-
-				else:
-					msg = "Loading default pre-trained weights \"{}\""
-					self.weights = str(Path(
-						self.data_info.BASE_DIR,
-						self.data_info.MODEL_DIR,
-						self.model_info.folder,
-						self.model_info.weights
-					))
-
-				loader_func = self.model.load_for_finetune
-
-			logging.info(msg.format(self.weights))
-			kwargs = dict(
-				weights=self.weights,
-				strict=args.load_strict,
-				path=args.load_path,
-				headless=args.headless,
-			)
-			loader = partial(loader_func, **kwargs)
-
-		feat_size = self.model.meta.feature_size
-
-		if hasattr(self.clf, "output_size"):
-			feat_size = self.clf.output_size
-
-		if hasattr(self.clf, "loader"):
-			loader = self.clf.loader(loader)
-
-		logging.info(f"Part features size after encoding: {feat_size}")
-		loader(n_classes=self.n_classes, feat_size=feat_size)
-		self.clf.cleargrads()
-
-class _DatasetMixin(abc.ABC):
-	"""
-		This mixin is responsible for annotation loading and for
-		dataset and iterator creation.
-	"""
-
-	def __init__(self, opts, dataset_cls, dataset_kwargs_factory, *args, **kwargs):
-		super(_DatasetMixin, self).__init__(opts=opts, *args, **kwargs)
-		self.annot = None
-		self.dataset_type = opts.dataset
-		self.dataset_cls = dataset_cls
-		self.dataset_kwargs_factory = dataset_kwargs_factory
-
-	@property
-	def n_classes(self):
-		return self.ds_info.n_classes + self.dataset_cls.label_shift
-
-	@property
-	def data_info(self):
-		assert self.annot is not None, "annot attribute was not set!"
-		return self.annot.info
-
-	@property
-	def ds_info(self):
-		return self.data_info.DATASETS[self.dataset_type]
-
-	def new_dataset(self, opts, size, part_size, subset):
-		"""Creates a dataset for a specific subset and certain options"""
-		if self.dataset_kwargs_factory is not None and callable(self.dataset_kwargs_factory):
-			kwargs = self.dataset_kwargs_factory(opts, subset)
-		else:
-			kwargs = dict()
-
-		kwargs = dict(kwargs,
-			subset=subset,
-			dataset_cls=self.dataset_cls,
-			prepare=self.prepare,
-			size=size,
-			part_size=part_size,
-			center_crop_on_val=getattr(opts, "center_crop_on_val", False),
-		)
-
-
-		ds = self.annot.new_dataset(**kwargs)
-		logging.info("Loaded {} images".format(len(ds)))
-		return ds
-
-
-	def init_annotations(self, opts):
-		"""Reads annotations and creates annotation instance, which holds important infos about the dataset"""
-
-		self.annot = AnnotationType.new_annotation(opts, load_strict=False)
-		self.dataset_cls.label_shift = opts.label_shift
-
-
-	def init_datasets(self, opts):
-
-		size = Size(opts.input_size)
-		part_size = getattr(opts, "parts_input_size", None)
-		part_size = size if part_size is None else Size(part_size)
-
-		self.prepare = partial(PrepareType[opts.prepare_type](self.model),
-			swap_channels=opts.swap_channels,
-			keep_ratio=getattr(opts, "center_crop_on_val", False),
-		)
-
-		logging.info(" ".join([
-			f"Created {self.model.__class__.__name__} model",
-			f"with \"{opts.prepare_type}\" prepare function."
-		]))
-
-		logging.info(" ".join([
-			f"Image input size: {size}",
-			f"Image parts input size: {part_size}",
-		]))
-
-		self.train_data = self.new_dataset(opts, size, part_size, "train")
-		self.val_data = self.new_dataset(opts, size, part_size, "test")
-
-	def init_iterators(self, opts):
-		"""Creates training and validation iterators from training and validation datasets"""
-
-		kwargs = dict(n_jobs=opts.n_jobs, batch_size=opts.batch_size)
-
-		if hasattr(self.train_data, "new_iterator"):
-			self.train_iter, _ = self.train_data.new_iterator(**kwargs)
-		else:
-			self.train_iter, _ = new_iterator(self.train_data, **kwargs)
-
-		if hasattr(self.val_data, "new_iterator"):
-			self.val_iter, _ = self.val_data.new_iterator(**kwargs,
-				repeat=False, shuffle=False
-			)
-		else:
-			self.val_iter, _ = new_iterator(self.val_data,
-				**kwargs, repeat=False, shuffle=False
-			)
-
-
-class _TrainerMixin(abc.ABC):
-	"""This mixin is responsible for updater, evaluator and trainer creation.
-	Furthermore, it implements the run method
-	"""
-
-	def __init__(self, updater_cls, updater_kwargs={}, *args, **kwargs):
-		super(_TrainerMixin, self).__init__(*args, **kwargs)
-		self.updater_cls = updater_cls
-		self.updater_kwargs = updater_kwargs
-
-	def init_updater(self):
-		"""Creates an updater from training iterator and the optimizer."""
-
-		if self.opt is None:
-			self.updater = None
-			return
-
-		self.updater = self.updater_cls(
-			iterator=self.train_iter,
-			optimizer=self.opt,
-			device=self.device,
-			**self.updater_kwargs,
-		)
-		logging.info(" ".join([
-			f"Using single GPU: {self.device}.",
-			f"{self.updater_cls.__name__} is initialized",
-			f"with following kwargs: {pretty_print_dict(self.updater_kwargs)}"
-			])
-		)
-
-	def init_evaluator(self, default_name="val"):
-		"""Creates evaluation extension from validation iterator and the classifier."""
-
-		self.evaluator = extensions.Evaluator(
-			iterator=self.val_iter,
-			target=self.clf,
-			device=self.device,
-			progress_bar=True
-		)
-
-		self.evaluator.default_name = default_name
-
-	def _new_trainer(self, trainer_cls, opts, *args, **kwargs):
-		return trainer_cls(
-			opts=opts,
-			updater=self.updater,
-			evaluator=self.evaluator,
-			*args, **kwargs
-		)
-
-	def run(self, trainer_cls, opts, *args, **kwargs):
-
-		trainer = self._new_trainer(trainer_cls, opts, *args, **kwargs)
-
-		self.save_meta_info(opts, folder=Path(trainer.out, "meta"))
-
-		logging.info("Snapshotting is {}abled".format("dis" if opts.no_snapshot else "en"))
-
-		def dump(suffix):
-			if opts.only_eval or opts.no_snapshot:
-				return
-
-			save_npz(Path(trainer.out, f"clf_{suffix}.npz"), self.clf)
-			save_npz(Path(trainer.out, f"model_{suffix}.npz"), self.model)
-
-		try:
-			trainer.run(opts.init_eval or opts.only_eval)
-		except (KeyboardInterrupt, BdbQuit) as e:
-			raise e
-		except Exception as e:
-			dump("exception")
-			raise e
-		else:
-			dump("final")
-
-	def save_meta_info(self, opts, folder: Path):
-		folder.mkdir(parents=True, exist_ok=True)
-
-		with open(folder / "args.yml", "w") as f:
-			pyaml.dump(opts.__dict__, f, sort_keys=True)
-
-
 
-class DefaultFinetuner(_ModelMixin, _DatasetMixin, _TrainerMixin):
+class DefaultFinetuner(mixins._ModelMixin, mixins._DatasetMixin, mixins._TrainerMixin):
 	""" The default Finetuner gathers together the creations of all needed
 	components and call them in the correct order
 
 	"""
 
 	def __init__(self, opts, *args, **kwargs):
-		super(DefaultFinetuner, self).__init__(*args, **kwargs)
+		super(DefaultFinetuner, self).__init__(opts=opts, *args, **kwargs)
 
 		self.gpu_config(opts)
 		cuda.get_device_from_id(self.device).use()
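
The functional change in this file is that DefaultFinetuner now forwards opts into the cooperative super().__init__() call, so each mixin in the MRO receives the options object it expects. A sketch of a full construction after this commit; the mixin parameter names are taken from the code, while the dataset class, the updater and the remaining opts fields are illustrative:

from types import SimpleNamespace
from chainer.training import StandardUpdater   # illustrative updater class
from chainer_addons.models import Classifier
from cvfinetune.finetuner import DefaultFinetuner

opts = SimpleNamespace()  # must provide the option fields read by the mixins
                          # (model_type, dataset, input_size, ...)

tuner = DefaultFinetuner(
    opts=opts,
    # consumed by mixins._ModelMixin
    classifier_cls=Classifier,
    classifier_kwargs={},
    model_kwargs={},
    # consumed by mixins._DatasetMixin
    dataset_cls=MyDataset,                     # placeholder dataset class
    dataset_kwargs_factory=None,
    # consumed by mixins._TrainerMixin
    updater_cls=StandardUpdater,
    updater_kwargs={},
)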

+ 50 - 0
cvfinetune/finetuner/factory.py

@@ -0,0 +1,50 @@
+import logging
+try:
+	import chainermn
+except Exception as e: #pragma: no cover
+	_CHAINERMN_AVAILABLE = False #pragma: no cover
+else:
+	_CHAINERMN_AVAILABLE = True
+
+from cvfinetune import utils
+from cvfinetune.finetuner.base import DefaultFinetuner
+from cvfinetune.finetuner.mpi import MPIFinetuner
+
+from cvdatasets.utils import pretty_print_dict
+
+class FinetunerFactory(object):
+
+	@classmethod
+	def new(cls, opts, default=DefaultFinetuner, mpi_tuner=MPIFinetuner):
+
+		if getattr(opts, "mpi", False):
+			assert _CHAINERMN_AVAILABLE, "Distributed training is not possible!"
+			msg1 = "MPI enabled. Creating NCCL communicator!"
+			comm = chainermn.create_communicator("pure_nccl")
+			msg2 = f"Rank: {comm.rank}, IntraRank: {comm.intra_rank}, InterRank: {comm.inter_rank}"
+
+			utils.log_messages([msg1, msg2])
+			return cls(mpi_tuner, comm=comm)
+		else:
+			return cls(default)
+
+	def __init__(self, tuner_cls, **kwargs):
+		super(FinetunerFactory, self).__init__()
+
+		self.tuner_cls = tuner_cls
+		self.kwargs = kwargs
+		logging.info(f"Using {self.tuner_cls.__name__} with arguments: {pretty_print_dict(self.kwargs)}")
+
+	def __call__(self, **kwargs):
+		_kwargs = dict(self.kwargs, **kwargs)
+
+		return self.tuner_cls(**_kwargs)
+
+	def get(self, key, default=None):
+		return self.kwargs.get(key, default)
+
+	def __getitem__(self, key):
+		return self.kwargs[key]
+
+	def __setitem__(self, key, value):
+		self.kwargs[key] = value
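
Apart from choosing the tuner class, the factory stores default keyword arguments (for example the MPI communicator) and merges per-call kwargs over them in __call__; it also offers dict-style access to the stored defaults. A small illustration with a stand-in tuner class (DummyTuner is not part of the package):

from cvfinetune.finetuner.factory import FinetunerFactory

class DummyTuner:
    # stand-in for DefaultFinetuner / MPIFinetuner
    def __init__(self, **kwargs):
        self.kwargs = kwargs

factory = FinetunerFactory(DummyTuner, device=0)

print(factory["device"])     # 0
print(factory.get("comm"))   # None: no MPI communicator was stored
factory["device"] = 1        # update a stored default

# per-call kwargs win over the stored ones: dict(self.kwargs, **kwargs)
tuner = factory(device=2)
print(tuner.kwargs)          # {'device': 2}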

+ 10 - 0
cvfinetune/finetuner/mixins/__init__.py

@@ -0,0 +1,10 @@
+from cvfinetune.finetuner.mixins.dataset import _DatasetMixin
+from cvfinetune.finetuner.mixins.model import _ModelMixin
+from cvfinetune.finetuner.mixins.trainer import _TrainerMixin
+
+
+__all__ = [
+	"_DatasetMixin",
+	"_ModelMixin",
+	"_TrainerMixin",
+]

+ 108 - 0
cvfinetune/finetuner/mixins/dataset.py

@@ -0,0 +1,108 @@
+import abc
+import logging
+
+from chainer_addons.models import PrepareType
+from cvdatasets import AnnotationType
+from cvdatasets.dataset.image import Size
+from cvdatasets.utils import new_iterator
+from functools import partial
+
+
+class _DatasetMixin(abc.ABC):
+	"""
+		This mixin is responsible for annotation loading and for
+		dataset and iterator creation.
+	"""
+
+	def __init__(self, opts, dataset_cls, dataset_kwargs_factory, *args, **kwargs):
+		super(_DatasetMixin, self).__init__(opts=opts, *args, **kwargs)
+		self.annot = None
+		self.dataset_type = opts.dataset
+		self.dataset_cls = dataset_cls
+		self.dataset_kwargs_factory = dataset_kwargs_factory
+
+	@property
+	def n_classes(self):
+		return self.ds_info.n_classes + self.dataset_cls.label_shift
+
+	@property
+	def data_info(self):
+		assert self.annot is not None, "annot attribute was not set!"
+		return self.annot.info
+
+	@property
+	def ds_info(self):
+		return self.data_info.DATASETS[self.dataset_type]
+
+	def new_dataset(self, opts, size, part_size, subset):
+		"""Creates a dataset for a specific subset and certain options"""
+		if self.dataset_kwargs_factory is not None and callable(self.dataset_kwargs_factory):
+			kwargs = self.dataset_kwargs_factory(opts, subset)
+		else:
+			kwargs = dict()
+
+		kwargs = dict(kwargs,
+			subset=subset,
+			dataset_cls=self.dataset_cls,
+			prepare=self.prepare,
+			size=size,
+			part_size=part_size,
+			center_crop_on_val=getattr(opts, "center_crop_on_val", False),
+		)
+
+
+		ds = self.annot.new_dataset(**kwargs)
+		logging.info("Loaded {} images".format(len(ds)))
+		return ds
+
+
+	def init_annotations(self, opts):
+		"""Reads the annotations and creates the annotation instance, which holds important information about the dataset"""
+
+		self.annot = AnnotationType.new_annotation(opts, load_strict=False)
+		self.dataset_cls.label_shift = opts.label_shift
+
+
+	def init_datasets(self, opts):
+
+		size = Size(opts.input_size)
+		part_size = getattr(opts, "parts_input_size", None)
+		part_size = size if part_size is None else Size(part_size)
+
+		self.prepare = partial(PrepareType[opts.prepare_type](self.model),
+			swap_channels=opts.swap_channels,
+			keep_ratio=getattr(opts, "center_crop_on_val", False),
+		)
+
+		logging.info(" ".join([
+			f"Created {self.model.__class__.__name__} model",
+			f"with \"{opts.prepare_type}\" prepare function."
+		]))
+
+		logging.info(" ".join([
+			f"Image input size: {size}",
+			f"Image parts input size: {part_size}",
+		]))
+
+		self.train_data = self.new_dataset(opts, size, part_size, "train")
+		self.val_data = self.new_dataset(opts, size, part_size, "test")
+
+	def init_iterators(self, opts):
+		"""Creates training and validation iterators from training and validation datasets"""
+
+		kwargs = dict(n_jobs=opts.n_jobs, batch_size=opts.batch_size)
+
+		if hasattr(self.train_data, "new_iterator"):
+			self.train_iter, _ = self.train_data.new_iterator(**kwargs)
+		else:
+			self.train_iter, _ = new_iterator(self.train_data, **kwargs)
+
+		if hasattr(self.val_data, "new_iterator"):
+			self.val_iter, _ = self.val_data.new_iterator(**kwargs,
+				repeat=False, shuffle=False
+			)
+		else:
+			self.val_iter, _ = new_iterator(self.val_data,
+				**kwargs, repeat=False, shuffle=False
+			)
+
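
The dataset_kwargs_factory argument is optional: when it is a callable, new_dataset() invokes it as factory(opts, subset) and merges the returned dict into the kwargs handed to annot.new_dataset(); the keys set afterwards (subset, dataset_cls, prepare, size, part_size, center_crop_on_val) take precedence. A hedged sketch of such a factory; the returned key is purely illustrative:

def my_dataset_kwargs(opts, subset):
    # called once per subset ("train" and "test"); whatever is returned here
    # is merged into the kwargs of self.annot.new_dataset(...)
    return dict(
        augment=(subset == "train"),   # illustrative key, not defined by this commit
    )

# passed through to _DatasetMixin.__init__ via the finetuner constructor, e.g.
# tuner = factory(opts=opts, dataset_cls=MyDataset,
#                 dataset_kwargs_factory=my_dataset_kwargs, ...)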

+ 159 - 0
cvfinetune/finetuner/mixins/model.py

@@ -0,0 +1,159 @@
+import abc
+import chainer
+import chainer.functions as F
+import logging
+
+from chainer.optimizer_hooks import Lasso
+from chainer.optimizer_hooks import WeightDecay
+from chainer_addons.functions import smoothed_cross_entropy
+from chainer_addons.models import ModelType
+from chainer_addons.training import optimizer
+from chainer_addons.training import optimizer_hooks
+from cvdatasets.dataset.image import Size
+from cvdatasets.utils import pretty_print_dict
+from functools import partial
+from pathlib import Path
+
+
+# Helpers used by init_optimizer() below; their original definitions remain in
+# cvfinetune.finetuner.base and are re-declared here in minimal form to avoid
+# a circular import between the base module and this mixin.
+def check_param_for_decay(param):
+	return param.name != "alpha"
+
+
+def enable_only_head(chain: chainer.Chain):
+	chain.disable_update()
+	chain.fc.enable_update()
+
+
+class _ModelMixin(abc.ABC):
+	"""
+		This mixin is responsible for optimizer creation, model creation,
+		model wrapping around a classifier and model weights loading.
+	"""
+
+	def __init__(self, opts, classifier_cls, classifier_kwargs={}, model_kwargs={}, *args, **kwargs):
+		super(_ModelMixin, self).__init__(opts=opts, *args, **kwargs)
+		self.classifier_cls = classifier_cls
+		self.classifier_kwargs = classifier_kwargs
+		self.model_type = opts.model_type
+		self.model_kwargs = model_kwargs
+
+
+	@property
+	def model_info(self):
+		return self.data_info.MODELS[self.model_type]
+
+	def wrap_model(self, opts):
+
+		clf_class, kwargs = self.classifier_cls, self.classifier_kwargs
+
+		self.clf = clf_class(
+			model=self.model,
+			loss_func=self._loss_func(opts),
+			**kwargs)
+
+		logging.info(" ".join([
+			f"Wrapped the model around {clf_class.__name__}",
+			f"with kwargs: {pretty_print_dict(kwargs)}",
+		]))
+
+	def _loss_func(self, opts):
+		if getattr(opts, "l1_loss", False):
+			return F.hinge
+
+		elif getattr(opts, "label_smoothing", 0) >= 0:
+			assert getattr(opts, "label_smoothing", 0) < 1, \
+				"Label smoothing factor must be less than 1!"
+			return partial(smoothed_cross_entropy,
+				N=self.n_classes,
+				eps=getattr(opts, "label_smoothing", 0))
+		else:
+			return F.softmax_cross_entropy
+
+	def init_optimizer(self, opts):
+		"""Creates an optimizer for the classifier """
+		if not hasattr(opts, "optimizer"):
+			self.opt = None
+			return
+
+		opt_kwargs = {}
+		if opts.optimizer == "rmsprop":
+			opt_kwargs["alpha"] = 0.9
+
+		self.opt = optimizer(opts.optimizer,
+			self.clf,
+			opts.learning_rate,
+			decay=0, gradient_clipping=False, **opt_kwargs
+		)
+
+		if opts.decay > 0:
+			reg_kwargs = {}
+			if opts.l1_loss:
+				reg_cls = Lasso
+
+			elif opts.pooling == "alpha":
+				reg_cls = optimizer_hooks.SelectiveWeightDecay
+				reg_kwargs["selection"] = check_param_for_decay
+
+			else:
+				reg_cls = WeightDecay
+
+			logging.info(f"Adding {reg_cls.__name__} ({opts.decay:e})")
+			self.opt.add_hook(reg_cls(opts.decay, **reg_kwargs))
+
+		if getattr(opts, "only_head", False):
+			assert not getattr(opts, "recurrent", False), \
+				"Recurrent classifier is not supported with only_head option!"
+
+			logging.warning("========= Fine-tuning only classifier layer! =========")
+			enable_only_head(self.clf)
+
+	def init_model(self, opts):
+		"""Creates the backbone CNN model. The classifier is later wrapped around this model."""
+
+		if self.model_type.startswith("cv2_"):
+			model_type = self.model_type.split("cv2_")[-1]
+		else:
+			model_type = self.model_info.class_key
+
+		self.model = ModelType.new(
+			model_type=model_type,
+			input_size=Size(opts.input_size),
+			**self.model_kwargs,
+		)
+
+	def load_model_weights(self, args):
+		if getattr(args, "from_scratch", False):
+			logging.info("Training a {0.__class__.__name__} model from scratch!".format(self.model))
+			loader = self.model.reinitialize_clf
+			self.weights = None
+		else:
+			if args.load:
+				self.weights = args.load
+				msg = "Loading already fine-tuned weights from \"{}\""
+				loader_func = self.model.load_for_inference
+			else:
+				if args.weights:
+					msg = "Loading custom pre-trained weights \"{}\""
+					self.weights = args.weights
+
+				else:
+					msg = "Loading default pre-trained weights \"{}\""
+					self.weights = str(Path(
+						self.data_info.BASE_DIR,
+						self.data_info.MODEL_DIR,
+						self.model_info.folder,
+						self.model_info.weights
+					))
+
+				loader_func = self.model.load_for_finetune
+
+			logging.info(msg.format(self.weights))
+			kwargs = dict(
+				weights=self.weights,
+				strict=args.load_strict,
+				path=args.load_path,
+				headless=args.headless,
+			)
+			loader = partial(loader_func, **kwargs)
+
+		feat_size = self.model.meta.feature_size
+
+		if hasattr(self.clf, "output_size"):
+			feat_size = self.clf.output_size
+
+		if hasattr(self.clf, "loader"):
+			loader = self.clf.loader(loader)
+
+		logging.info(f"Part features size after encoding: {feat_size}")
+		loader(n_classes=self.n_classes, feat_size=feat_size)
+		self.clf.cleargrads()
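
For orientation, these are the main option fields read by _ModelMixin, gathered from the code above into one namespace; the concrete values are examples only, not defaults shipped by the package:

from types import SimpleNamespace

opts = SimpleNamespace(
    model_type="cv2_resnet50",   # a "cv2_" prefix is stripped in init_model()
    input_size=448,
    # loss selection in _loss_func()
    l1_loss=False,               # True switches to F.hinge (and Lasso regularization)
    label_smoothing=0.1,         # must be < 1; used with smoothed_cross_entropy
    # optimizer setup in init_optimizer()
    optimizer="rmsprop",         # "rmsprop" additionally sets alpha=0.9
    learning_rate=1e-3,
    decay=5e-4,                  # > 0 adds a WeightDecay / Lasso / SelectiveWeightDecay hook
    pooling="avg",               # "alpha" selects SelectiveWeightDecay
    only_head=False,             # True freezes everything except the classification layer
    # weight loading in load_model_weights()
    from_scratch=False,
    load=None,                   # path to already fine-tuned weights
    weights=None,                # path to custom pre-trained weights
    load_strict=False,
    load_path="",
    headless=False,
)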

+ 92 - 0
cvfinetune/finetuner/mixins/trainer.py

@@ -0,0 +1,92 @@
+import abc
+import logging
+import pyaml
+
+from bdb import BdbQuit
+from chainer.serializers import save_npz
+from chainer.training import extensions
+from cvdatasets.utils import pretty_print_dict
+from pathlib import Path
+
+
+class _TrainerMixin(abc.ABC):
+	"""This mixin is responsible for updater, evaluator and trainer creation.
+	Furthermore, it implements the run method
+	"""
+
+	def __init__(self, opts, updater_cls, updater_kwargs={}, *args, **kwargs):
+		# "opts" is accepted (and consumed) here so it does not propagate past
+		# the end of the cooperative __init__ chain of the mixins.
+		super(_TrainerMixin, self).__init__(*args, **kwargs)
+		self.updater_cls = updater_cls
+		self.updater_kwargs = updater_kwargs
+
+	def init_updater(self):
+		"""Creates an updater from training iterator and the optimizer."""
+
+		if self.opt is None:
+			self.updater = None
+			return
+
+		self.updater = self.updater_cls(
+			iterator=self.train_iter,
+			optimizer=self.opt,
+			device=self.device,
+			**self.updater_kwargs,
+		)
+		logging.info(" ".join([
+			f"Using single GPU: {self.device}.",
+			f"{self.updater_cls.__name__} is initialized",
+			f"with following kwargs: {pretty_print_dict(self.updater_kwargs)}"
+			])
+		)
+
+	def init_evaluator(self, default_name="val"):
+		"""Creates evaluation extension from validation iterator and the classifier."""
+
+		self.evaluator = extensions.Evaluator(
+			iterator=self.val_iter,
+			target=self.clf,
+			device=self.device,
+			progress_bar=True
+		)
+
+		self.evaluator.default_name = default_name
+
+	def _new_trainer(self, trainer_cls, opts, *args, **kwargs):
+		return trainer_cls(
+			opts=opts,
+			updater=self.updater,
+			evaluator=self.evaluator,
+			*args, **kwargs
+		)
+
+	def run(self, trainer_cls, opts, *args, **kwargs):
+
+		trainer = self._new_trainer(trainer_cls, opts, *args, **kwargs)
+
+		self.save_meta_info(opts, folder=Path(trainer.out, "meta"))
+
+		logging.info("Snapshotting is {}abled".format("dis" if opts.no_snapshot else "en"))
+
+		def dump(suffix):
+			if opts.only_eval or opts.no_snapshot:
+				return
+
+			save_npz(Path(trainer.out, f"clf_{suffix}.npz"), self.clf)
+			save_npz(Path(trainer.out, f"model_{suffix}.npz"), self.model)
+
+		try:
+			trainer.run(opts.init_eval or opts.only_eval)
+		except (KeyboardInterrupt, BdbQuit) as e:
+			raise e
+		except Exception as e:
+			dump("exception")
+			raise e
+		else:
+			dump("final")
+
+	def save_meta_info(self, opts, folder: Path):
+		folder.mkdir(parents=True, exist_ok=True)
+
+		with open(folder / "args.yml", "w") as f:
+			pyaml.dump(opts.__dict__, f, sort_keys=True)
+
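
run() ties everything together: it builds the trainer via _new_trainer(), writes the options to meta/args.yml inside the trainer's output folder, and saves clf_<suffix>.npz / model_<suffix>.npz either after a clean run ("final") or when an unexpected exception escapes ("exception"), unless only_eval or no_snapshot is set. A sketch of the calling side; MyTrainer is a placeholder that must accept opts, updater and evaluator as keyword arguments, and tuner is a finetuner built as in the earlier sketches:

# opts.init_eval / opts.only_eval request an (initial) evaluation pass,
# opts.no_snapshot disables the clf_*.npz / model_*.npz dumps
tuner.run(MyTrainer, opts)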