Browse Source

some functions were missing

Dimitri Korsch 5 years ago
parent
commit
26392d602a
2 changed files with 13 additions and 13 deletions
  1. 0 12
      cvfinetune/finetuner/base.py
  2. 13 1
      cvfinetune/finetuner/mixins/model.py

+ 0 - 12
cvfinetune/finetuner/base.py

@@ -4,18 +4,6 @@ from chainer.backends import cuda
 
 
 from cvfinetune.finetuner import mixins
 from cvfinetune.finetuner import mixins
 
 
-def check_param_for_decay(param):
-	return param.name != "alpha"
-
-def enable_only_head(chain: chainer.Chain):
-	if hasattr(chain, "enable_only_head") and callable(chain.enable_only_head):
-		chain.enable_only_head()
-
-	else:
-		chain.disable_update()
-		chain.fc.enable_update()
-
-
 class DefaultFinetuner(mixins._ModelMixin, mixins._DatasetMixin, mixins._TrainerMixin):
 class DefaultFinetuner(mixins._ModelMixin, mixins._DatasetMixin, mixins._TrainerMixin):
 	""" The default Finetuner gathers together the creations of all needed
 	""" The default Finetuner gathers together the creations of all needed
 	components and call them in the correct order
 	components and call them in the correct order

+ 13 - 1
cvfinetune/finetuner/mixins/model.py

@@ -1,7 +1,8 @@
 import abc
 import abc
-import chainer.functions as F
+import chainer
 import logging
 import logging
 
 
+from chainer import functions as F
 from chainer.optimizer_hooks import Lasso
 from chainer.optimizer_hooks import Lasso
 from chainer.optimizer_hooks import WeightDecay
 from chainer.optimizer_hooks import WeightDecay
 from chainer_addons.functions import smoothed_cross_entropy
 from chainer_addons.functions import smoothed_cross_entropy
@@ -18,6 +19,17 @@ from pathlib import Path
 from typing import Callable
 from typing import Callable
 from typing import Tuple
 from typing import Tuple
 
 
+def check_param_for_decay(param):
+	return param.name != "alpha"
+
+def enable_only_head(chain: chainer.Chain):
+	if hasattr(chain, "enable_only_head") and callable(chain.enable_only_head):
+		chain.enable_only_head()
+
+	else:
+		chain.disable_update()
+		chain.fc.enable_update()
+
 
 
 class _ModelMixin(abc.ABC):
 class _ModelMixin(abc.ABC):
 	"""
 	"""