|
@@ -57,6 +57,7 @@ class Trainer(T):
|
|
)
|
|
)
|
|
|
|
|
|
### Evaluator ###
|
|
### Evaluator ###
|
|
|
|
+ self.evaluator = evaluator
|
|
if evaluator is not None:
|
|
if evaluator is not None:
|
|
self.extend(evaluator, trigger=intervals.eval)
|
|
self.extend(evaluator, trigger=intervals.eval)
|
|
|
|
|
|
@@ -99,7 +100,7 @@ class Trainer(T):
|
|
self.setup_snapshots(opts, clf.model, intervals.snapshot)
|
|
self.setup_snapshots(opts, clf.model, intervals.snapshot)
|
|
|
|
|
|
### Reports and Plots ###
|
|
### Reports and Plots ###
|
|
- print_values, plot_values = self.reportables(opts, model, evaluator)
|
|
|
|
|
|
+ print_values, plot_values = self.reportables()
|
|
self.extend(extensions.PrintReport(print_values), trigger=intervals.print)
|
|
self.extend(extensions.PrintReport(print_values), trigger=intervals.print)
|
|
for name, values in plot_values.items():
|
|
for name, values in plot_values.items():
|
|
ext = extensions.PlotReport(values, 'epoch', file_name='{}.png'.format(name))
|
|
ext = extensions.PlotReport(values, 'epoch', file_name='{}.png'.format(name))
|
|
@@ -118,27 +119,30 @@ class Trainer(T):
|
|
self.extend(extensions.snapshot_object(obj, dump_fmt), trigger=trigger)
|
|
self.extend(extensions.snapshot_object(obj, dump_fmt), trigger=trigger)
|
|
logging.info("Snapshot format: \"{}\"".format(dump_fmt))
|
|
logging.info("Snapshot format: \"{}\"".format(dump_fmt))
|
|
|
|
|
|
|
|
+ def eval_name(self, name):
|
|
|
|
+ if self.evaluator is None:
|
|
|
|
+ return name
|
|
|
|
|
|
- def reportables(self, opts, model, evaluator):
|
|
|
|
- eval_name = lambda name: f"{evaluator.default_name}/{name}"
|
|
|
|
|
|
+ return f"{self.evaluator.default_name}/{name}"
|
|
|
|
|
|
|
|
+ def reportables(self):
|
|
|
|
|
|
print_values = [
|
|
print_values = [
|
|
"elapsed_time",
|
|
"elapsed_time",
|
|
"epoch",
|
|
"epoch",
|
|
# "lr",
|
|
# "lr",
|
|
|
|
|
|
- "main/accuracy", eval_name("main/accuracy"),
|
|
|
|
- "main/loss", eval_name("main/loss"),
|
|
|
|
|
|
+ "main/accuracy", self.eval_name("main/accuracy"),
|
|
|
|
+ "main/loss", self.eval_name("main/loss"),
|
|
|
|
|
|
]
|
|
]
|
|
|
|
|
|
plot_values = {
|
|
plot_values = {
|
|
"accuracy": [
|
|
"accuracy": [
|
|
- "main/accuracy", eval_name("main/accuracy"),
|
|
|
|
|
|
+ "main/accuracy", self.eval_name("main/accuracy"),
|
|
],
|
|
],
|
|
"loss": [
|
|
"loss": [
|
|
- "main/loss", eval_name("main/loss"),
|
|
|
|
|
|
+ "main/loss", self.eval_name("main/loss"),
|
|
],
|
|
],
|
|
}
|
|
}
|
|
|
|
|
|
@@ -150,34 +154,6 @@ class Trainer(T):
|
|
# ]
|
|
# ]
|
|
# })
|
|
# })
|
|
|
|
|
|
- # if opts.use_parts:
|
|
|
|
- # print_values.extend(["main/logL", eval_name("main/logL")])
|
|
|
|
- # plot_values.update({
|
|
|
|
- # "logL": [
|
|
|
|
- # "main/logL", eval_name("main/logL"),
|
|
|
|
- # ]
|
|
|
|
- # })
|
|
|
|
-
|
|
|
|
- # if not opts.no_global:
|
|
|
|
- # print_values.extend([
|
|
|
|
- # "main/glob_accu", eval_name("main/glob_accu"),
|
|
|
|
- # # "main/glob_loss", eval_name("main/glob_loss"),
|
|
|
|
-
|
|
|
|
- # "main/part_accu", eval_name("main/part_accu"),
|
|
|
|
- # # "main/part_loss", eval_name("main/part_loss"),
|
|
|
|
- # ])
|
|
|
|
-
|
|
|
|
- # plot_values["accuracy"].extend([
|
|
|
|
- # "main/part_accu", eval_name("main/part_accu"),
|
|
|
|
- # "main/glob_accu", eval_name("main/glob_accu"),
|
|
|
|
- # ])
|
|
|
|
-
|
|
|
|
- # plot_values["loss"].extend([
|
|
|
|
- # "main/part_loss", eval_name("main/part_loss"),
|
|
|
|
- # "main/glob_loss", eval_name("main/glob_loss"),
|
|
|
|
- # ])
|
|
|
|
-
|
|
|
|
-
|
|
|
|
return print_values, plot_values
|
|
return print_values, plot_values
|
|
|
|
|
|
|
|
|
|
@@ -194,8 +170,7 @@ class Trainer(T):
|
|
def run(self, init_eval=True):
|
|
def run(self, init_eval=True):
|
|
if init_eval:
|
|
if init_eval:
|
|
logging.info("Evaluating initial model ...")
|
|
logging.info("Evaluating initial model ...")
|
|
- evaluator = self.get_extension("val")
|
|
|
|
- init_perf = evaluator(self)
|
|
|
|
|
|
+ init_perf = self.evaluator(self)
|
|
logging.info("Initial accuracy: {val/main/accuracy:.3%} initial loss: {val/main/loss:.3f}".format(
|
|
logging.info("Initial accuracy: {val/main/accuracy:.3%} initial loss: {val/main/loss:.3f}".format(
|
|
**{key: float(value) for key, value in init_perf.items()}
|
|
**{key: float(value) for key, value in init_perf.items()}
|
|
))
|
|
))
|
|
@@ -210,19 +185,22 @@ class SacredTrainer(Trainer):
|
|
|
|
|
|
class AlphaPoolingTrainer(SacredTrainer):
|
|
class AlphaPoolingTrainer(SacredTrainer):
|
|
|
|
|
|
|
|
+ @property
|
|
|
|
+ def model(self):
|
|
|
|
+ return self.updater.get_optimizer("main").target.model
|
|
|
|
+
|
|
def __init__(self, opts, updater, *args, **kwargs):
|
|
def __init__(self, opts, updater, *args, **kwargs):
|
|
super(AlphaPoolingTrainer, self).__init__(opts=opts, updater=updater, *args, **kwargs)
|
|
super(AlphaPoolingTrainer, self).__init__(opts=opts, updater=updater, *args, **kwargs)
|
|
- model = updater.get_optimizer("main").target.model
|
|
|
|
### Alternating training of CNN and FC layers (only for alpha-pooling) ###
|
|
### Alternating training of CNN and FC layers (only for alpha-pooling) ###
|
|
if opts.switch_epochs:
|
|
if opts.switch_epochs:
|
|
self.extend(SwitchTrainables(
|
|
self.extend(SwitchTrainables(
|
|
opts.switch_epochs,
|
|
opts.switch_epochs,
|
|
- model=model,
|
|
|
|
- pooling=model.pool))
|
|
|
|
|
|
+ model=self.model,
|
|
|
|
+ pooling=self.model.pool))
|
|
|
|
|
|
- def reportables(self, opts, model, evaluator):
|
|
|
|
- print_values, plot_values = super(AlphaPoolingTrainer, self).reportables(opts, model, evaluator)
|
|
|
|
- alpha_update_rule = model.pool.alpha.update_rule
|
|
|
|
|
|
+ def reportables(self):
|
|
|
|
+ print_values, plot_values = super(AlphaPoolingTrainer, self).reportables()
|
|
|
|
+ alpha_update_rule = self.model.pool.alpha.update_rule
|
|
if _is_adam(opts):
|
|
if _is_adam(opts):
|
|
# in case of Adam optimizer
|
|
# in case of Adam optimizer
|
|
alpha_update_rule.hyperparam.alpha *= opts.kappa
|
|
alpha_update_rule.hyperparam.alpha *= opts.kappa
|