
Updated code for only_head training

Dimitri Korsch committed 4 years ago
Parent
Current commit
8d09a95de0
2 changed files with 26 additions and 17 deletions
  1. cvfinetune/classifier.py (+9, -0)
  2. cvfinetune/finetuner/base.py (+17, -17)

cvfinetune/classifier.py (+9, -0)

@@ -25,6 +25,10 @@ class Classifier(C):
 	def output_size(self):
 		return self.feat_size
 
+	def enable_only_head(self):
+		self.model.disable_update()
+		self.model.fc.enable_update()
+
 
 class SeparateModelClassifier(Classifier):
 	"""Classifier, that holds two separate models"""
@@ -67,3 +71,8 @@ class SeparateModelClassifier(Classifier):
 				feat_size=self.feat_size)
 
 		return inner
+
+	def enable_only_head(self):
+		super(SeparateModelClassifier, self).enable_only_head()
+		self.separate_model.disable_update()
+		self.separate_model.fc.enable_update()
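
Below is a minimal, self-contained sketch (not part of the commit) of the freezing behaviour that enable_only_head() relies on. The Backbone class and its layer names are hypothetical stand-ins for the wrapped model; only disable_update(), enable_update(), and the optimizer setup mirror the Chainer API used above. Note that the *_update() toggles act on the parameters' update rules, which exist only after an optimizer has been set up on the model.

import chainer
import chainer.links as L

class Backbone(chainer.Chain):
	"""Hypothetical stand-in for the wrapped CNN model."""
	def __init__(self, feat_size=256, n_classes=10):
		super().__init__()
		with self.init_scope():
			self.conv = L.Convolution2D(3, feat_size, ksize=3)
			self.fc = L.Linear(feat_size, n_classes)

model = Backbone()
opt = chainer.optimizers.Adam()
opt.setup(model)  # creates the update rules that the toggles below act on

model.disable_update()    # disable updates for every parameter in the chain ...
model.fc.enable_update()  # ... then re-enable them only for the classification head

# Only the head's parameters remain trainable:
trainable = [name for name, p in model.namedparams() if p.update_rule.enabled]
print(sorted(trainable))  # -> ['/fc/W', '/fc/b']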

cvfinetune/finetuner/base.py (+17, -17)

@@ -32,6 +32,14 @@ from pathlib import Path
 def check_param_for_decay(param):
 	return param.name != "alpha"
 
+def enable_only_head(chain: chainer.Chain):
+	if hasattr(chain, "enable_only_head") and callable(chain.enable_only_head):
+		chain.enable_only_head()
+
+	else:
+		chain.disable_update()
+		chain.fc.enable_update()
+
 class _ModelMixin(abc.ABC):
 	"""This mixin is responsible for optimizer creation, model creation,
 	model wrapping around a classifier and model weights loading.
@@ -102,11 +110,11 @@ class _ModelMixin(abc.ABC):
 			self.opt.add_hook(reg_cls(opts.decay, **reg_kwargs))
 
 		if getattr(opts, "only_head", False):
-			assert not opts.recurrent, "FIX ME! Not supported yet!"
+			assert not getattr(opts, "recurrent", False), \
+				"Recurrent classifier is not supported with only_head option!"
 
 			logging.warning("========= Fine-tuning only classifier layer! =========")
-			self.model.disable_update()
-			self.model.fc.enable_update()
+			enable_only_head(self.clf)
 
 	def init_model(self, opts):
 		"""creates backbone CNN model. This model is wrapped around the classifier later"""
@@ -185,20 +193,15 @@ class _DatasetMixin(abc.ABC):
 		else:
 			kwargs = dict()
 
-		kwargs.update(dict(
+		kwargs = dict(kwargs,
 			subset=subset,
 			dataset_cls=self.dataset_cls,
-		))
-
-
-		if not getattr(opts, "only_head", False):
-			kwargs.update(dict(
-				prepare=self.prepare,
-				size=size,
-				part_size=part_size,
-				center_crop_on_val=getattr(opts, "center_crop_on_val", False),
+			prepare=self.prepare,
+			size=size,
+			part_size=part_size,
+			center_crop_on_val=getattr(opts, "center_crop_on_val", False),
+		)
 
-			))
 
 		ds = self.annot.new_dataset(**kwargs)
 		logging.info("Loaded {} images".format(len(ds)))
@@ -214,9 +217,6 @@ class _DatasetMixin(abc.ABC):
 		self.ds_info = self.data_info.DATASETS[opts.dataset]
 		# self.part_info = self.data_info.PART_TYPES[opts.parts]
 
-		if getattr(opts, "only_head", False):
-			self.annot.feature_model = opts.model_type
-
 		self.dataset_cls.label_shift = opts.label_shift
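
For context, here is a runnable sketch (an illustration, not repository code) of how the new module-level enable_only_head helper dispatches: a classifier that defines its own enable_only_head (such as SeparateModelClassifier above) gets its method called, while any other chain falls back to freezing everything except its fc link. The PlainClassifier and HeadAwareClassifier names are hypothetical.

import chainer
import chainer.links as L

def enable_only_head(chain: chainer.Chain):
	# same duck-typed dispatch as in cvfinetune/finetuner/base.py
	if hasattr(chain, "enable_only_head") and callable(chain.enable_only_head):
		chain.enable_only_head()
	else:
		chain.disable_update()
		chain.fc.enable_update()

class PlainClassifier(chainer.Chain):
	def __init__(self):
		super().__init__()
		with self.init_scope():
			self.fc = L.Linear(16, 4)

class HeadAwareClassifier(PlainClassifier):
	def enable_only_head(self):
		# classifier-specific freezing logic would go here
		self.disable_update()
		self.fc.enable_update()

enable_only_head(HeadAwareClassifier())  # uses the classifier's own method
enable_only_head(PlainClassifier())      # generic fallback: freeze all but fc
# (the *_update() calls only take effect once an optimizer has created
#  update rules for the parameters, e.g. via optimizer.setup(model))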