123456789101112131415161718192021 |
# Script: load the loss curves recorded for the ``Germany_1_0`` run and draw
# them together on a single figure via the project's ``Plotter``.
import numpy as np

from src.plotter import Plotter

SRC_DIR = './results/'
SIM_DIR = './visualizations/'  # not referenced in this script; kept for other users
MET_DIR = 'training_metrics/'


def _load_metric(file_name):
    """Read one comma-delimited metrics file from ``SRC_DIR + MET_DIR``.

    Args:
        file_name: Name of the CSV file inside the metrics directory.

    Returns:
        The array produced by ``np.genfromtxt`` for that file.
    """
    return np.genfromtxt(SRC_DIR + MET_DIR + file_name, delimiter=',')


plotter = Plotter()

# Loss curves of the Germany_1_0 run (one CSV per loss term).
physics_loss = _load_metric('Germany_1_0_physics_loss.csv')
obs_loss = _load_metric('Germany_1_0_obs_loss.csv')
loss = _load_metric('Germany_1_0_loss.csv')

# x-axis: one integer index per entry of the total-loss series.
t = np.arange(len(loss))

# NOTE(review): the meaning of the positional arguments is defined by
# Plotter.plot, which is not visible here -- presumably x values, series,
# legend labels, output name/title, axis label and figure size; confirm
# against src/plotter.py.
plotter.plot(
    t,
    [loss, physics_loss, obs_loss],
    ['loss', 'physics loss', 'data loss'],
    'Germany_1_loss',
    'Loss',
    (6, 6),
    y_log_scale=True,
)
|