|
@@ -1,7 +1,8 @@
|
|
|
import abc
|
|
|
-import chainer.functions as F
|
|
|
+import chainer
|
|
|
import logging
|
|
|
|
|
|
+from chainer import functions as F
|
|
|
from chainer.optimizer_hooks import Lasso
|
|
|
from chainer.optimizer_hooks import WeightDecay
|
|
|
from chainer_addons.functions import smoothed_cross_entropy
|
|
@@ -18,6 +19,17 @@ from pathlib import Path
|
|
|
from typing import Callable
|
|
|
from typing import Tuple
|
|
|
|
|
|
+def check_param_for_decay(param):
|
|
|
+ return param.name != "alpha"
|
|
|
+
|
|
|
+def enable_only_head(chain: chainer.Chain):
|
|
|
+ if hasattr(chain, "enable_only_head") and callable(chain.enable_only_head):
|
|
|
+ chain.enable_only_head()
|
|
|
+
|
|
|
+ else:
|
|
|
+ chain.disable_update()
|
|
|
+ chain.fc.enable_update()
|
|
|
+
|
|
|
|
|
|
class _ModelMixin(abc.ABC):
|
|
|
"""
|