|
@@ -32,6 +32,14 @@ from pathlib import Path
|
|
|
def check_param_for_decay(param):
|
|
|
return param.name != "alpha"
|
|
|
|
|
|
+def enable_only_head(chain: chainer.Chain):
|
|
|
+ if hasattr(chain, "enable_only_head") and callable(chain.enable_only_head):
|
|
|
+ chain.enable_only_head()
|
|
|
+
|
|
|
+ else:
|
|
|
+ chain.disable_update()
|
|
|
+ chain.fc.enable_update()
|
|
|
+
|
|
|
class _ModelMixin(abc.ABC):
|
|
|
"""This mixin is responsible for optimizer creation, model creation,
|
|
|
model wrapping around a classifier and model weights loading.
|
|
@@ -102,11 +110,11 @@ class _ModelMixin(abc.ABC):
|
|
|
self.opt.add_hook(reg_cls(opts.decay, **reg_kwargs))
|
|
|
|
|
|
if getattr(opts, "only_head", False):
|
|
|
- assert not opts.recurrent, "FIX ME! Not supported yet!"
|
|
|
+ assert not getattr(opts, "recurrent", False), \
|
|
|
+ "Recurrent classifier is not supported with only_head option!"
|
|
|
|
|
|
logging.warning("========= Fine-tuning only classifier layer! =========")
|
|
|
- self.model.disable_update()
|
|
|
- self.model.fc.enable_update()
|
|
|
+ enable_only_head(self.clf)
|
|
|
|
|
|
def init_model(self, opts):
|
|
|
"""creates backbone CNN model. This model is wrapped around the classifier later"""
|
|
@@ -185,20 +193,15 @@ class _DatasetMixin(abc.ABC):
|
|
|
else:
|
|
|
kwargs = dict()
|
|
|
|
|
|
- kwargs.update(dict(
|
|
|
+ kwargs = dict(kwargs,
|
|
|
subset=subset,
|
|
|
dataset_cls=self.dataset_cls,
|
|
|
- ))
|
|
|
-
|
|
|
-
|
|
|
- if not getattr(opts, "only_head", False):
|
|
|
- kwargs.update(dict(
|
|
|
- prepare=self.prepare,
|
|
|
- size=size,
|
|
|
- part_size=part_size,
|
|
|
- center_crop_on_val=getattr(opts, "center_crop_on_val", False),
|
|
|
+ prepare=self.prepare,
|
|
|
+ size=size,
|
|
|
+ part_size=part_size,
|
|
|
+ center_crop_on_val=getattr(opts, "center_crop_on_val", False),
|
|
|
+ )
|
|
|
|
|
|
- ))
|
|
|
|
|
|
ds = self.annot.new_dataset(**kwargs)
|
|
|
logging.info("Loaded {} images".format(len(ds)))
|
|
@@ -214,9 +217,6 @@ class _DatasetMixin(abc.ABC):
|
|
|
self.ds_info = self.data_info.DATASETS[opts.dataset]
|
|
|
# self.part_info = self.data_info.PART_TYPES[opts.parts]
|
|
|
|
|
|
- if getattr(opts, "only_head", False):
|
|
|
- self.annot.feature_model = opts.model_type
|
|
|
-
|
|
|
self.dataset_cls.label_shift = opts.label_shift
|
|
|
|
|
|
|