12345678910111213141516171819202122232425262728293031 |
- from finetune.training.trainer import Trainer
class PartsTrainer(Trainer):
    """Trainer that additionally reports the ``main/glob_*``, ``main/part_*``
    and ``main/logL`` metrics.

    Every reported key is paired with its evaluation counterpart, as produced
    by ``self.eval_name``.
    """

    def reportables(self):
        """Return the values to print and to plot during training.

        Returns:
            tuple: ``(print_vals, plot_vals)`` taken from the parent
            ``reportables()`` and extended as follows:

            * ``print_vals`` gets the glob/part accuracies and logL, each
              followed by its ``eval_name`` counterpart.
            * ``plot_vals["logL"]`` is (re)set to the logL key pair.
            * ``plot_vals["accuracy"]`` and ``plot_vals["loss"]`` are
              extended with the part/glob key pairs. They are created as
              empty lists first if the parent did not provide them.
        """
        print_vals, plot_vals = super(PartsTrainer, self).reportables()

        def with_eval(*keys):
            # Interleave each key with its evaluation key:
            # [k1, eval_name(k1), k2, eval_name(k2), ...]
            return [name for key in keys for name in (key, self.eval_name(key))]

        # Per-head losses (main/*_loss) are only plotted, not printed.
        print_vals.extend(
            with_eval("main/glob_accu", "main/part_accu", "main/logL"))

        plot_vals["logL"] = with_eval("main/logL")
        # setdefault: do not rely on the parent having created these plots.
        plot_vals.setdefault("accuracy", []).extend(
            with_eval("main/part_accu", "main/glob_accu"))
        plot_vals.setdefault("loss", []).extend(
            with_eval("main/part_loss", "main/glob_loss"))

        return print_vals, plot_vals
|