@@ -1,20 +1,19 @@
-from .sacred import SacredTrainer
-from .base import _is_adam, default_intervals
-
 from chainer.training import extensions
 
+from cvfinetune.training.trainer import base
+
 def observe_alpha(trainer):
     model = trainer.updater.get_optimizer("main").target.model
     return float(model.pool.alpha.array)
 
-class AlphaPoolingTrainer(SacredTrainer):
+class AlphaPoolingTrainer(base.Trainer):
 
     @property
     def model(self):
         return self.updater.get_optimizer("main").target.model
 
     def __init__(self, opts, updater, *args, **kwargs):
-        super(AlphaPoolingTrainer, self).__init__(opts=opts, updater=updater, *args, **kwargs)
+        super().__init__(opts=opts, updater=updater, *args, **kwargs)
         ### Alternating training of CNN and FC layers (only for alpha-pooling) ###
         if opts.switch_epochs:
             self.extend(SwitchTrainables(
@@ -23,9 +22,9 @@ class AlphaPoolingTrainer(SacredTrainer):
                 pooling=self.model.pool))
 
     def reportables(self, opts):
-        print_values, plot_values = super(AlphaPoolingTrainer, self).reportables()
+        print_values, plot_values = super().reportables()
         alpha_update_rule = self.model.pool.alpha.update_rule
-        if _is_adam(opts):
+        if base._is_adam(opts):
             # in case of Adam optimizer
             alpha_update_rule.hyperparam.alpha *= opts.kappa
         else:
@@ -33,7 +32,7 @@
 
         self.extend(
             extensions.observe_value("alpha", observe_alpha),
-            trigger=default_intervals.print)
+            trigger=base.default_intervals.print)
 
         print_values.append("alpha")
         plot_values["alpha"]= ["alpha"]