
copied some code from the older version

Dimitri Korsch, 6 years ago
commit 65c5dc7fbf

+ 46 - 0
finetune/classifier.py

@@ -0,0 +1,46 @@
+import chainer
+import chainer.functions as F
+import chainer.links as L
+
+from chainer_addons.models.base import BaseClassifier
+import logging
+
+class SeparateModelClassifier(BaseClassifier):
+	"""Classifier that holds two separate models."""
+	def __init__(self, *args, **kwargs):
+		super(SeparateModelClassifier, self).__init__(*args, **kwargs)
+
+		with self.init_scope():
+			self.init_separate_model()
+
+	def init_separate_model(self):
+
+		assert hasattr(self, "model"), \
+			"This classifier has no \"model\" attribute!"
+
+		if hasattr(self, "separate_model"):
+			logging.warning("Global Model already initialized! Skipping further execution!")
+			return
+
+		self.separate_model = self.model.copy(mode="copy")
+
+	def loader(self, model_loader):
+
+		def inner(n_classes, feat_size):
+			# load the main model with the given feature size
+			model_loader(n_classes=n_classes, feat_size=feat_size)
+
+			# use the given feature size first ...
+			self.separate_model.reinitialize_clf(
+				n_classes=n_classes,
+				feat_size=feat_size)
+
+			# then copy model params ...
+			self.separate_model.copyparams(self.model)
+
+			# now use the default feature size to re-init the classifier
+			self.separate_model.reinitialize_clf(
+				n_classes=n_classes,
+				feat_size=self.feat_size)
+
+		return inner
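
Usage sketch (not part of the diff): the deferred loader above is meant to wrap a plain model loader, so that separate_model first receives a copy of the freshly loaded parameters and only afterwards gets its classification layer re-initialized at the classifier's default feature size (self.feat_size). The backbone model, the weights path, and the class count below are placeholders; the actual call site is load_model_weights() in finetune/finetuner/base.py further down in this commit.

    from functools import partial

    # `model` is assumed to be any backbone offering load_for_finetune() and
    # reinitialize_clf(), e.g. created via ModelType.new(...) as in base.py
    clf = SeparateModelClassifier(model=model, loss_func=F.softmax_cross_entropy)

    loader = partial(model.load_for_finetune, weights="<pretrained.npz>")

    # the wrapper copies the loaded weights into clf.separate_model before
    # re-initializing its classification layer at the default feature size
    loader = clf.loader(loader)
    loader(n_classes=200, feat_size=model.meta.feature_size)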

+ 44 - 0
finetune/dataset.py

@@ -0,0 +1,44 @@
+import numpy as np
+import abc
+
+from chainer_addons.dataset import AugmentationMixin
+from chainer_addons.dataset import PreprocessMixin
+
+from cvdatasets.dataset import AnnotationsReadMixin
+from cvdatasets.dataset import RevealedPartMixin
+from cvdatasets.dataset import IteratorMixin
+
+class _pre_augmentation_mixin(abc.ABC):
+	""" This mixin discards the parts from the ImageWrapper object
+	and shifts the labels
+	"""
+
+	label_shift = 1
+
+	def get_example(self, i):
+		im_obj = super(_pre_augmentation_mixin, self).get_example(i)
+		im, parts, lab = im_obj.as_tuple()
+		return im, lab + self.label_shift
+
+class _base_mixin(abc.ABC):
+	""" This mixin converts images that are in the range
+	[0..1] to the range [-1..1]
+	"""
+
+	def get_example(self, i):
+		im, lab = super(_base_mixin, self).get_example(i)
+		if isinstance(im, list):
+			im = np.array(im)
+		return im * 2 - 1, lab
+
+
+class BaseDataset(_base_mixin,
+	# augmentation and preprocessing
+	AugmentationMixin, PreprocessMixin,
+	_pre_augmentation_mixin,
+	# random uniform region selection
+	RevealedPartMixin,
+	# reads image
+	AnnotationsReadMixin,
+	IteratorMixin):
+	"""Commonly used dataset constellation"""
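
Note (not part of the diff): get_example is resolved along the MRO from left to right, so an item is read by AnnotationsReadMixin, stripped of its parts and label-shifted by _pre_augmentation_mixin, then preprocessed and augmented, and finally rescaled by _base_mixin. The rescaling is a plain linear map from [0..1] to [-1..1]:

    import numpy as np

    # illustration of the _base_mixin rescaling on a dummy CHW image in [0..1]
    im = np.random.uniform(0, 1, size=(3, 224, 224)).astype(np.float32)
    scaled = im * 2 - 1
    assert scaled.min() >= -1 and scaled.max() <= 1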

+ 0 - 0
finetune/finetuner/__init__.py


+ 339 - 0
finetune/finetuner/base.py

@@ -0,0 +1,339 @@
+import chainer
+import chainer.functions as F
+import numpy as np
+
+import abc
+import logging
+
+from chainer.backends import cuda
+from chainer.optimizer_hooks import Lasso
+from chainer.optimizer_hooks import WeightDecay
+from chainer.training import StandardUpdater, extensions
+from chainer.serializers import save_npz
+
+from chainer_addons.models import ModelType
+from chainer_addons.models import Classifier
+from chainer_addons.models import PrepareType
+from chainer_addons.training import optimizer, optimizer_hooks
+from chainer_addons.functions import smoothed_cross_entropy
+
+from cvdatasets.annotations import AnnotationType
+from cvdatasets.utils import new_iterator
+
+from finetune.core.classifier import FVEMixin
+from finetune.core.classifier import BasePartsClassifier
+from finetune.core.training import AlphaPoolingTrainer
+from finetune.core.training import Trainer
+from finetune.core.dataset import Dataset
+
+from functools import partial
+from os.path import join
+
+from bdb import BdbQuit
+
+
+def check_param_for_decay(param):
+	return param.name != "alpha"
+
+
+class _ModelMixin(abc.ABC):
+	"""This mixin is responsible for creating the optimizer and the model,
+	wrapping the model in a classifier, and loading the model weights.
+	"""
+
+	def __init__(self, classifier_cls, classifier_kwargs, *args, **kwargs):
+		super(_ModelMixin, self).__init__(*args, **kwargs)
+		self.classifier_kwargs = classifier_kwargs
+		self.classifier_cls = classifier_cls
+
+	def wrap_model(self, opts):
+
+		clf_class, kwargs = self.classifier_cls, self.classifier_kwargs
+
+		self.clf = clf_class(
+			model=self.model,
+			loss_func=self._loss_func(opts),
+			**kwargs)
+
+		kwargs_info = " ".join([f"{key}={value}" for key, value in kwargs.items()])
+		logging.info(" ".join([
+			f"Wrapped {clf_class.__name__} around the model",
+			f"with kwargs: {kwargs_info}",
+		]))
+
+	def _loss_func(self, opts):
+		if opts.l1_loss:
+			return F.hinge
+
+		elif opts.label_smoothing >= 0:
+			assert opts.label_smoothing < 1, \
+				"Label smoothing factor must be less than 1!"
+			return partial(smoothed_cross_entropy,
+				N=self.n_classes,
+				eps=opts.label_smoothing)
+		else:
+			return F.softmax_cross_entropy
+
+	def init_optimizer(self, opts):
+		"""Creates an optimizer for the classifier """
+
+		opt_kwargs = {}
+		if opts.optimizer == "rmsprop":
+			opt_kwargs["alpha"] = 0.9
+
+		self.opt = optimizer(opts.optimizer,
+			self.clf,
+			opts.learning_rate,
+			decay=0, gradient_clipping=False, **opt_kwargs
+		)
+
+		if opts.decay:
+			reg_kwargs = {}
+			if opts.l1_loss:
+				reg_cls = Lasso
+
+			elif opts.pooling == "alpha":
+				reg_cls = optimizer_hooks.SelectiveWeightDecay
+				reg_kwargs["selection"] = check_param_for_decay
+
+			else:
+				reg_cls = WeightDecay
+
+			logging.info(f"Adding {reg_cls.__name__} ({opts.decay:e})")
+			self.opt.add_hook(reg_cls(opts.decay, **reg_kwargs))
+
+		if opts.only_head:
+			assert not opts.recurrent, "FIX ME! Not supported yet!"
+
+			logging.warning("========= Fine-tuning only classifier layer! =========")
+			self.model.disable_update()
+			self.model.fc.enable_update()
+
+	def init_model(self, opts):
+		"""Creates the backbone CNN model. The classifier is wrapped around this model later."""
+
+		self.model = ModelType.new(
+			model_type=self.model_info.class_key,
+			input_size=opts.input_size,
+			pooling=opts.pooling,
+			pooling_params=dict(
+				init_alpha=opts.init_alpha,
+				output_dim=8192,
+				normalize=opts.normalize),
+			aux_logits=False
+		)
+
+	def load_model_weights(self, args):
+		if args.from_scratch:
+			logging.info("Training a {0.__class__.__name__} model from scratch!".format(self.model))
+			loader = self.model.reinitialize_clf
+			self.weights = None
+		else:
+			if args.load:
+				self.weights = args.load
+				logging.info("Loading already fine-tuned weights from \"{}\"".format(self.weights))
+				loader = partial(self.model.load_for_inference, weights=self.weights)
+			else:
+				self.weights = join(
+					self.data_info.BASE_DIR,
+					self.data_info.MODEL_DIR,
+					self.model_info.folder,
+					self.model_info.weights
+				)
+				logging.info("Loading pre-trained weights \"{}\"".format(self.weights))
+				loader = partial(self.model.load_for_finetune, weights=self.weights)
+
+
+		if hasattr(self.model.pool, "output_dim") and self.model.pool.output_dim is not None:
+			feat_size = self.model.pool.output_dim
+
+		elif isinstance(self.clf, (BasePartsClassifier, FVEMixin)):
+			feat_size = self.clf.outsize
+
+		else:
+			feat_size = self.model.meta.feature_size
+
+		if hasattr(self.clf, "loader"):
+			loader = self.clf.loader(loader)
+
+		logging.info(f"Part features size after encoding: {feat_size}")
+		loader(n_classes=self.n_classes, feat_size=feat_size)
+		self.clf.cleargrads()
+
+class _DatasetMixin(abc.ABC):
+	"""
+		This mixin is responsible for annotation loading and for
+		dataset and iterator creation.
+	"""
+
+	def __init__(self, dataset_cls, *args, **kwargs):
+		super(_DatasetMixin, self).__init__(*args, **kwargs)
+		self.dataset_cls = dataset_cls
+
+	@property
+	def n_classes(self):
+		return self.part_info.n_classes + self.dataset_cls.label_shift
+
+	def new_dataset(self, opts, size, subset, augment):
+		"""Creates a dataset for a specific subset and certain options"""
+		kwargs = dict(
+			subset=subset,
+			dataset_cls=self.dataset_cls,
+		)
+		if opts.use_parts:
+			kwargs.update(dict(
+				no_glob=opts.no_global,
+			))
+
+		if not opts.only_head:
+			kwargs.update(dict(
+				preprocess=self.prepare,
+				augment=augment,
+				size=size,
+				center_crop_on_val=not opts.no_center_crop_on_val,
+
+				# return_part_crops=args.use_parts,
+			))
+
+		d = self.annot.new_dataset(**kwargs)
+		logging.info("Loaded {} images".format(len(d)))
+		logging.info("Data augmentation is {}abled".format("en" if augment else "dis"))
+		logging.info("Global feature is {}used".format("not " if opts.no_global else ""))
+		return d
+
+	def init_annotations(self, opts):
+		"""Reads the annotations and creates the annotation instance, which holds important information about the dataset"""
+
+		annot_cls = AnnotationType.get(opts.dataset).value
+		self.annot = annot_cls(opts.data, opts.parts)
+
+		self.data_info = self.annot.info
+		self.model_info = self.data_info.MODELS[opts.model_type]
+		self.part_info = self.data_info.PARTS[opts.parts]
+
+		if opts.only_head:
+			self.annot.feature_model = opts.model_type
+
+	def init_datasets(self, opts):
+
+		self.dataset_cls.label_shift = opts.label_shift
+
+		size = 112 if opts.recurrent else self.model.meta.input_size
+
+		self.prepare = partial(PrepareType[opts.prepare_type](self.model),
+			swap_channels=opts.swap_channels,
+			keep_ratio=not opts.no_center_crop_on_val,
+		)
+
+		logging.info(" ".join([
+			f"Created {self.model.__class__.__name__} model",
+			f"with \"{opts.prepare_type}\" prepare function.",
+			f"Image input size: {size}",
+		]))
+
+		self.train_data = self.new_dataset(opts, size, "train", True)
+		self.val_data = self.new_dataset(opts, size, "test", False)
+
+	def init_iterators(self, opts):
+		"""Creates training and validation iterators from training and validation datasets"""
+
+		self.train_iter, _ = new_iterator(self.train_data,
+			opts.n_jobs, opts.batch_size
+		)
+
+		self.val_iter, _ = new_iterator(self.val_data,
+			opts.n_jobs, opts.batch_size,
+			repeat=False, shuffle=False
+		)
+
+class _TrainerMixin(abc.ABC):
+	"""This mixin is responsible for updater, evaluator and trainer creation.
+	Furthermore, it implements the run method
+	"""
+
+	def init_updater(self, updater_cls=StandardUpdater, updater_kwargs={}):
+		"""Creates an updater from training iterator and the optimizer."""
+
+		self.updater = updater_cls(
+			iterator=self.train_iter,
+			optimizer=self.opt,
+			device=self.device,
+			**updater_kwargs,
+		)
+		logging.info(f"Using single GPU: {self.device}. {updater_cls.__name__} is initialized.")
+
+	def init_evaluator(self, default_name="val"):
+		"""Creates evaluation extension from validation iterator and the classifier."""
+
+		self.evaluator = extensions.Evaluator(
+			iterator=self.val_iter,
+			target=self.clf,
+			device=self.device)
+
+		self.evaluator.default_name = default_name
+
+	def run(self, opts, ex, no_observe=False):
+
+		trainer_cls = AlphaPoolingTrainer if opts.pooling=="alpha" else Trainer
+		trainer = trainer_cls(
+			ex=ex,
+			opts=opts,
+			updater=self.updater,
+			evaluator=self.evaluator,
+			weights=self.weights,
+			no_observe=no_observe
+		)
+		def dump(suffix):
+			if opts.only_eval or opts.no_snapshot:
+				return
+
+			save_npz(join(trainer.out,
+				"clf_{}.npz".format(suffix)), self.clf)
+			save_npz(join(trainer.out,
+				"model_{}.npz".format(suffix)), self.model)
+
+		try:
+			trainer.run(init_eval=opts.init_eval or opts.only_eval)
+		except (KeyboardInterrupt, BdbQuit) as e:
+			raise e
+		except Exception as e:
+			dump("exception")
+			raise e
+		else:
+			dump("final")
+
+
+class DefaultFinetuner(_ModelMixin, _DatasetMixin, _TrainerMixin):
+	"""The default finetuner gathers the creation of all required
+	components and calls them in the correct order.
+
+	"""
+
+
+	def __init__(self, opts, *args, **kwargs):
+		super(DefaultFinetuner, self).__init__()
+
+		self.gpu_config(opts, *args, **kwargs)
+
+	def gpu_config(self, opts):
+		if -1 in opts.gpu:
+			self.device = -1
+		else:
+			self.device = opts.gpu[0]
+		cuda.get_device_from_id(self.device).use()
+
+	def setup(self, opts, updater_cls, updater_kwargs):
+
+		self.init_annotations(opts)
+		self.init_model(opts)
+
+		self.init_datasets(opts)
+		self.init_iterators(opts)
+
+		self.wrap_model(opts)
+		self.load_model_weights(opts)
+
+		self.init_optimizer(opts)
+		self.init_updater(updater_cls=updater_cls, updater_kwargs=updater_kwargs)
+		self.init_evaluator()
+
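
Driver sketch (not part of the diff): this commit does not ship a training script, so the snippet below only illustrates the order in which setup() wires the mixins together. The opts namespace, the ex experiment handle, and the constructor keyword arguments are assumptions derived from the mixin signatures above; as committed, DefaultFinetuner.__init__ does not yet forward them to the mixins.

    from chainer.training import StandardUpdater
    from chainer_addons.models import Classifier

    from finetune.dataset import BaseDataset
    from finetune.finetuner.base import DefaultFinetuner

    # hypothetical: `opts` is an argparse namespace, `ex` an experiment
    # handle that is passed on to the Trainer
    tuner = DefaultFinetuner(opts,
        classifier_cls=Classifier,
        classifier_kwargs={},
        dataset_cls=BaseDataset)

    # setup() runs: annotations -> model -> datasets/iterators ->
    # classifier wrapper -> weight loading -> optimizer -> updater -> evaluator
    tuner.setup(opts, updater_cls=StandardUpdater, updater_kwargs={})
    tuner.run(opts, ex)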

+ 54 - 0
finetune/finetuner/mpi.py

@@ -0,0 +1,54 @@
+from .base import DefaultFinetuner
+
+class MPIFinetuner(DefaultFinetuner):
+
+	@property
+	def mpi(self):
+		return self.comm is not None
+
+	@property
+	def mpi_main_process(self):
+		return not self.mpi or self.comm.rank == 0
+
+	def gpu_config(self, opts, comm=None):
+		super(MPIFinetuner, self).gpu_config(opts)
+
+		self.comm = comm
+		if self.mpi:
+			self.device = opts.gpu[self.comm.rank]
+
+			# self.device += self.comm.intra_rank
+
+	def scatter_datasets(self):
+		if self.mpi:
+			from chainermn import scatter_dataset as scatter
+			self.train_data = scatter(self.train_data, self.comm)
+			self.val_data = scatter(self.val_data, self.comm)
+
+	def init_datasets(self, *args, **kwargs):
+
+		if self.mpi_main_process:
+			super(MPIFinetuner, self).init_datasets(*args, **kwargs)
+		else:
+			self.train_data, self.val_data = None, None
+
+		# scatter_dataset is a collective call, so every rank has to reach it
+		self.scatter_datasets()
+
+	def init_optimizer(self, opts):
+		super(MPIFinetuner, self).init_optimizer(opts)
+
+		if self.mpi:
+			import chainermn
+			self.opt = chainermn.create_multi_node_optimizer(self.opt, self.comm)
+
+	def init_evaluator(self):
+		super(MPIFinetuner, self).init_evaluator()
+
+		if self.mpi:
+			import chainermn
+			self.evaluator = chainermn.create_multi_node_evaluator(
+				self.evaluator, self.comm)
+
+	def run(self, opts, ex):
+		super(MPIFinetuner, self).run(opts, ex, no_observe=not self.mpi_main_process)
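
Launch sketch (not part of the diff): one MPI process per GPU, started e.g. with mpirun; the communicator reaches gpu_config() through the constructor's keyword arguments. The mixin constructor arguments (classifier_cls, classifier_kwargs, dataset_cls) are omitted here for brevity and carry the same assumptions as the DefaultFinetuner sketch above.

    import chainermn
    from chainer.training import StandardUpdater

    from finetune.finetuner.mpi import MPIFinetuner

    # e.g.:  mpirun -np 2 python train.py ...
    comm = chainermn.create_communicator("pure_nccl")

    # opts.gpu is expected to list one device id per MPI rank
    tuner = MPIFinetuner(opts, comm=comm)
    tuner.setup(opts, updater_cls=StandardUpdater, updater_kwargs={})

    # non-main ranks run with no_observe=True and only contribute to the
    # multi-node optimizer and evaluator
    tuner.run(opts, ex)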