Ver Fonte

some functions were missing

Dimitri Korsch há 4 anos atrás
pai
commit
26392d602a
2 ficheiros alterados com 13 adições e 13 exclusões
  1. 0 12
      cvfinetune/finetuner/base.py
  2. 13 1
      cvfinetune/finetuner/mixins/model.py

+ 0 - 12
cvfinetune/finetuner/base.py

@@ -4,18 +4,6 @@ from chainer.backends import cuda
 
 
 from cvfinetune.finetuner import mixins
 from cvfinetune.finetuner import mixins
 
 
-def check_param_for_decay(param):
-	return param.name != "alpha"
-
-def enable_only_head(chain: chainer.Chain):
-	if hasattr(chain, "enable_only_head") and callable(chain.enable_only_head):
-		chain.enable_only_head()
-
-	else:
-		chain.disable_update()
-		chain.fc.enable_update()
-
-
 class DefaultFinetuner(mixins._ModelMixin, mixins._DatasetMixin, mixins._TrainerMixin):
 class DefaultFinetuner(mixins._ModelMixin, mixins._DatasetMixin, mixins._TrainerMixin):
 	""" The default Finetuner gathers together the creations of all needed
 	""" The default Finetuner gathers together the creations of all needed
 	components and call them in the correct order
 	components and call them in the correct order

+ 13 - 1
cvfinetune/finetuner/mixins/model.py

@@ -1,7 +1,8 @@
 import abc
 import abc
-import chainer.functions as F
+import chainer
 import logging
 import logging
 
 
+from chainer import functions as F
 from chainer.optimizer_hooks import Lasso
 from chainer.optimizer_hooks import Lasso
 from chainer.optimizer_hooks import WeightDecay
 from chainer.optimizer_hooks import WeightDecay
 from chainer_addons.functions import smoothed_cross_entropy
 from chainer_addons.functions import smoothed_cross_entropy
@@ -18,6 +19,17 @@ from pathlib import Path
 from typing import Callable
 from typing import Callable
 from typing import Tuple
 from typing import Tuple
 
 
+def check_param_for_decay(param):
+	return param.name != "alpha"
+
+def enable_only_head(chain: chainer.Chain):
+	if hasattr(chain, "enable_only_head") and callable(chain.enable_only_head):
+		chain.enable_only_head()
+
+	else:
+		chain.disable_update()
+		chain.fc.enable_update()
+
 
 
 class _ModelMixin(abc.ABC):
 class _ModelMixin(abc.ABC):
 	"""
 	"""