trainer.py 859 B

12345678910111213141516171819202122232425262728293031
  1. from finetune.training.trainer import Trainer
  2. class PartsTrainer(Trainer):
  3. def reportables(self):
  4. print_vals, plot_vals = super(PartsTrainer, self).reportables()
  5. print_vals.extend([
  6. "main/glob_accu", self.eval_name("main/glob_accu"),
  7. "main/part_accu", self.eval_name("main/part_accu"),
  8. "main/logL", self.eval_name("main/logL"),
  9. # "main/glob_loss", self.eval_name("main/glob_loss"),
  10. # "main/part_loss", self.eval_name("main/part_loss"),
  11. ])
  12. plot_vals["logL"] = ["main/logL", self.eval_name("main/logL")]
  13. plot_vals["accuracy"].extend([
  14. "main/part_accu", self.eval_name("main/part_accu"),
  15. "main/glob_accu", self.eval_name("main/glob_accu"),
  16. ])
  17. plot_vals["loss"].extend([
  18. "main/part_loss", self.eval_name("main/part_loss"),
  19. "main/glob_loss", self.eval_name("main/glob_loss"),
  20. ])
  21. return print_vals, plot_vals