Ver código fonte

some functions were missing

Dimitri Korsch 4 anos atrás
pai
commit
26392d602a
2 arquivos alterados com 13 adições e 13 exclusões
  1. 0 12
      cvfinetune/finetuner/base.py
  2. 13 1
      cvfinetune/finetuner/mixins/model.py

+ 0 - 12
cvfinetune/finetuner/base.py

@@ -4,18 +4,6 @@ from chainer.backends import cuda
 
 from cvfinetune.finetuner import mixins
 
-def check_param_for_decay(param):
-	return param.name != "alpha"
-
-def enable_only_head(chain: chainer.Chain):
-	if hasattr(chain, "enable_only_head") and callable(chain.enable_only_head):
-		chain.enable_only_head()
-
-	else:
-		chain.disable_update()
-		chain.fc.enable_update()
-
-
 class DefaultFinetuner(mixins._ModelMixin, mixins._DatasetMixin, mixins._TrainerMixin):
 	""" The default Finetuner gathers together the creations of all needed
 	components and call them in the correct order

+ 13 - 1
cvfinetune/finetuner/mixins/model.py

@@ -1,7 +1,8 @@
 import abc
-import chainer.functions as F
+import chainer
 import logging
 
+from chainer import functions as F
 from chainer.optimizer_hooks import Lasso
 from chainer.optimizer_hooks import WeightDecay
 from chainer_addons.functions import smoothed_cross_entropy
@@ -18,6 +19,17 @@ from pathlib import Path
 from typing import Callable
 from typing import Tuple
 
+def check_param_for_decay(param):
+	return param.name != "alpha"
+
+def enable_only_head(chain: chainer.Chain):
+	if hasattr(chain, "enable_only_head") and callable(chain.enable_only_head):
+		chain.enable_only_head()
+
+	else:
+		chain.disable_update()
+		chain.fc.enable_update()
+
 
 class _ModelMixin(abc.ABC):
 	"""