|
@@ -103,7 +103,7 @@ class Trainer(T):
|
|
|
self.setup_snapshots(opts, clf.model, intervals.snapshot)
|
|
self.setup_snapshots(opts, clf.model, intervals.snapshot)
|
|
|
|
|
|
|
|
### Reports and Plots ###
|
|
### Reports and Plots ###
|
|
|
- print_values, plot_values = self.reportables()
|
|
|
|
|
|
|
+ print_values, plot_values = self.reportables(opts)
|
|
|
self.extend(extensions.PrintReport(print_values), trigger=intervals.print)
|
|
self.extend(extensions.PrintReport(print_values), trigger=intervals.print)
|
|
|
for name, values in plot_values.items():
|
|
for name, values in plot_values.items():
|
|
|
ext = extensions.PlotReport(values, 'epoch', file_name='{}.png'.format(name))
|
|
ext = extensions.PlotReport(values, 'epoch', file_name='{}.png'.format(name))
|
|
@@ -128,7 +128,7 @@ class Trainer(T):
|
|
|
|
|
|
|
|
return f"{self.evaluator.default_name}/{name}"
|
|
return f"{self.evaluator.default_name}/{name}"
|
|
|
|
|
|
|
|
- def reportables(self):
|
|
|
|
|
|
|
+ def reportables(self, opts):
|
|
|
|
|
|
|
|
print_values = [
|
|
print_values = [
|
|
|
"elapsed_time",
|
|
"elapsed_time",
|