# plot_training_metrics.py

  1. import numpy as np
  2. from src.plotter import Plotter
  3. SRC_DIR = './results/'
  4. SIM_DIR = './visualizations/'
  5. MET_DIR = 'training_metrics/'
  6. plotter = Plotter()
  7. # plot synthetic loss
  8. physics_loss = np.genfromtxt(SRC_DIR + MET_DIR + f'Germany_1_0_physics_loss.csv', delimiter=',')
  9. obs_loss = np.genfromtxt(SRC_DIR + MET_DIR + f'Germany_1_0_obs_loss.csv', delimiter=',')
  10. loss = np.genfromtxt(SRC_DIR + MET_DIR + f'Germany_1_0_loss.csv', delimiter=',')
  11. t = np.arange(0, len(loss), 1)
  12. plotter.plot(t,
  13. [loss, physics_loss, obs_loss],
  14. ['loss', 'physics loss', 'data loss'],
  15. 'Germany_1_loss',
  16. 'Loss',
  17. (6, 6),
  18. y_log_scale=True)