plot scripts for thesis

phillip.rothenbeck 4 months ago
parent
commit eefda9b551
3 changed files with 432 additions and 0 deletions
  1. plot_datasets.py  +109 -0
  2. plot_results.py  +302 -0
  3. plot_training_metrics.py  +21 -0

plot_datasets.py  +109 -0

@@ -0,0 +1,109 @@
+import numpy as np
+from src.plotter import Plotter
+
+DS_DIR = './datasets/'
+SRC_DIR = './results/'
+SIM_DIR = './visualizations/'
+
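+# maps the dataset file-name keys to the display names used in the plot titles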
+STATE_LOOKUP = {'Schleswig_Holstein' : 'Schleswig-Holstein',
+                'Hamburg' : 'Hamburg', 
+                'Niedersachsen' : 'Lower Saxony',
+                'Bremen' : 'Bremen',
+                'Nordrhein_Westfalen' : 'North Rhine-Westphalia',
+                'Hessen' : 'Hesse',
+                'Rheinland_Pfalz' : 'Rhineland-Palatinate',
+                'Baden_Wuerttemberg' : 'Baden-Württemberg',
+                'Bayern' : 'Bavaria',
+                'Saarland' : 'Saarland',
+                'Berlin' : 'Berlin',
+                'Brandenburg' : 'Brandenburg',
+                'Mecklenburg_Vorpommern' : 'Mecklenburg-Western Pomerania',
+                'Sachsen' : 'Saxony',
+                'Sachsen_Anhalt' : 'Saxony-Anhalt',
+                'Thueringen' : 'Thuringia'}
+
+plotter = Plotter()
+
+data = []
+# first entry: the synthetic SIR dataset; the 1 is a placeholder that is
+# overwritten with the Germany-wide dataset further below
+in_text_data = [np.genfromtxt(DS_DIR + 'SIR_data.csv', delimiter=',')[1:], 1]
+for state in STATE_LOOKUP.keys():
+    # print(f"plot {state}")
+    state_data_5 = np.genfromtxt(DS_DIR + f'I_RKI_{state}_1_5.csv', delimiter=',')[1]
+    state_data_14 = np.genfromtxt(DS_DIR + f'I_RKI_{state}_1_14.csv', delimiter=',')[1]
+    sir_data = np.genfromtxt(DS_DIR + f'SIR_RKI_{state}_1_14.csv', delimiter=',')[1:]
+    data.append(sir_data)
+    t = np.arange(0, 1200, 1)
+
+    if state in ["Schleswig_Holstein", "Berlin", "Thueringen"]:
+        in_text_data.append(sir_data)
+
+    plotter.plot(t, 
+                 [state_data_14, state_data_5], 
+                 [r'$\alpha=\frac{1}{14}$', r'$\alpha=\frac{1}{5}$'],
+                 f'{state}_datasets',
+                 f'{STATE_LOOKUP[state]}',
+                 (12,6),
+                 xlabel='time / days',
+                 ylabel='number of people')
+    
+do_log = False
+# the two cluster figures cover data[:6] and data[7:]; index 6 (Rheinland_Pfalz)
+# is not part of either cluster
+plotter.cluster_plot(t,
+                     data[:6],
+                     [r'$S$', r'$I$', r'$R$'],
+                     (2, 3),
+                     (6,6),
+                     "state_sir_cluster_1",
+                     list(STATE_LOOKUP.values())[:6],
+                     xlabel='time / days',
+                     ylabel='number of people',
+                     y_log_scale=do_log,
+                     add_y_space=0.05,
+                     number_of_legend_columns=3,
+                     same_axes=False,
+                     ylim=(0, 1.85e7))
+plotter.cluster_plot(t,
+                     data[7:],
+                     [r'$S$', r'$I$', r'$R$'],
+                     (3, 3),
+                     (6,6),
+                     "state_sir_cluster_2",
+                     list(STATE_LOOKUP.values())[7:],
+                     xlabel='time / days',
+                     ylabel='number of people',
+                     y_log_scale=do_log,
+                     add_y_space=0.03,
+                     number_of_legend_columns=3,
+                     same_axes=False,
+                     ylim=(0, 1.85e7))
+
+germany_data = np.genfromtxt(DS_DIR + f'SIR_RKI_Germany_1_14.csv', delimiter=',')[1:]
+in_text_data[1] = germany_data
+
+plotter.plot(t,
+             germany_data,
+             [r'$S$', r'$I$', r'$R$'],
+             'germany_single_sir',
+             'Germany',
+             (6,6),
+             plot_legend=False,
+             xlabel='time / days',
+             ylabel='number of people')
+
+plotter.cluster_plot(t, 
+                     in_text_data,
+                     [r'$S$', r'$I$', r'$R$'],
+                     (2, 3),
+                     (6, 6),
+                     "in_text_SIR",
+                     ["synthetic SIR data", "Germany", 'Schleswig-Holstein', 'Berlin', 'Thuringia'],
+                     xlabel='time / days',
+                     ylabel='number of people',
+                     legend_loc=(0.51, 0.8),
+                     add_y_space=0,
+                     same_axes=False,
+                     free_axis=(0, 1),
+                     plot_all_labels=False)
+
+
+

plot_results.py  +302 -0

@@ -0,0 +1,302 @@
+import numpy as np
+from src.plotter import Plotter
+
+SRC_DIR = './results/'
+I_PRED_SRC_DIR = SRC_DIR + 'I_predictions/'
+SIM_DIR = './visualizations/'
+
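+# get_error: for each entry y[i] (e.g. one fitted R_t curve or one I prediction),
+# compute the relative L2 error ||y[i] - y_ref|| / ||y_ref|| and return the mean
+# over all entries; note that for a 1-D y (as in the synthetic calls below) each
+# y[i] is a scalar and is broadcast against y_ref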
+def get_error(y, y_ref):
+    err = []
+    for i in range(len(y)):
+        diff = y[i] - y_ref
+        err.append(np.linalg.norm(diff) / np.linalg.norm(y_ref))
+    return np.array(err).mean(axis=0)
+
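+# per-state tuples: the first entry is the vaccination ratio in percent (used for
+# the correlation with the fitted beta values); the meaning of the second entry is
+# not documented here, it is only echoed in the LaTeX table output below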
+STATE_LOOKUP = {'Schleswig_Holstein' : (79.5,0.0849),
+                'Hamburg' : (84.5, 0.0948),
+                'Niedersachsen' : (77.6, 0.0774),
+                'Bremen' : (88.3,0.0933),
+                'Nordrhein_Westfalen' : (79.5,0.0777),
+                'Hessen' : (75.8,0.1017),
+                'Rheinland_Pfalz' : (75.6,0.0895),
+                'Baden_Wuerttemberg' : (74.5,0.0796),
+                'Bayern' : (75.1,0.0952),
+                'Saarland' : (82.4,0.1080),
+                'Berlin' : (78.1,0.0667),
+                'Brandenburg' : (68.1,0.0724),
+                'Mecklenburg_Vorpommern' : (74.7,0.0540),
+                'Sachsen' : (65.1,0.1109),
+                'Sachsen_Anhalt' : (74.1,0.0785),
+                'Thueringen' : (70.3,0.0837),
+                'Germany' : (76.4, 0.0804)}
+
+state_names = ['Schleswig-Holstein',
+                'Hamburg', 
+                'Lower Saxony', 
+                'Bremen',
+                'North Rhine-Westphalia',
+                'Hesse',
+                'Rhineland-Palatinate',
+                'Baden-Württemberg',
+                'Bavaria',
+                'Saarland',
+                'Berlin',
+                'Brandenburg',
+                'Mecklenburg-Western Pomerania',
+                'Saxony',
+                'Saxony-Anhalt',
+                'Thuringia', 
+                'Germany']
+
+plotter = Plotter(additional_colors=['yellow', 'cyan', 'magenta', ])
+
+# plot results for alpha and beta
+
+print("Visualizing Alpha and Beta results")
+
+# synth
+param_matrix = np.genfromtxt(SRC_DIR + f'synthetic_parameters.csv', delimiter=',')
+mean = param_matrix.mean(axis=0)
+std = param_matrix.std(axis=0)
+
+print("States Table form:")
+print('{0:.4f}'.format(1/3), "&", '{0:.4f}'.format(mean[0]), "&", '{0:.4f}'.format(std[0]), "&", '{0:.4f}'.format(1/2), "&", '{0:.4f}'.format(mean[1]), "&", '{0:.4f}'.format(std[1]), "\\\ ")
+
+plotter.scatter(np.arange(1, 6, 1), [param_matrix[:,0], param_matrix[:,1]], [r"$\alpha$", r"$\beta$"], (7,3.5), 'reproducability', '', true_values=[1/3, 1/2], xlabel='iteration')
+
+vaccination_ratios = []
+mean_std_parameters = {}
+for state in STATE_LOOKUP.keys():
+    state_matrix = np.genfromtxt(SRC_DIR + f'{state}_parameters.csv', delimiter=',')
+    mean = state_matrix.mean(axis=0)
+    std = state_matrix.std(axis=0)
+    mean_std_parameters.update({state : (mean, std)})
+    vaccination_ratios.append(STATE_LOOKUP[state][0])
+
+values = np.array(list(mean_std_parameters.values()))
+means = values[:,0]
+stds = values[:,1]
+alpha_means = means[:,0]
+beta_means = means[:,1]
+alpha_stds = stds[:,0]
+beta_stds = stds[:,1]
+
+print(f"Vaccination corr: {np.corrcoef(beta_means, vaccination_ratios)[0, 1]}")
+vaccination_ratios = vaccination_ratios[:-1]  # drop the Germany entry
+sn = np.array(state_names[:-1]).copy()  # state labels without Germany
+sn[12] = "MWP"  # abbreviate Mecklenburg-Western Pomerania for the x-axis
+plotter.scatter(sn, 
+                [alpha_means[:-1], beta_means[:-1]], 
+                [r'$\alpha$', r'$\beta$', ],
+                (12, 6), 
+                'mean_std_alpha_beta_res',
+                '',
+                std=[alpha_stds[:-1], beta_stds[:-1]],
+                true_values=[alpha_means[-1], beta_means[-1]],
+                true_label='Germany',
+                xlabel_rotation=60,
+                plot_legend=True,
+                legend_loc="lower right")
+
+print("States Table form:")
+for i, state in enumerate(STATE_LOOKUP.keys()):
+    print(state_names[i], "&", '{0:.3f}'.format(alpha_means[i]), "{\\tiny $\\pm", '{0:.3f}'.format(alpha_stds[i]), "$}", "&", '{0:.3f}'.format(beta_means[i]), "{\\tiny $\\pm", '{0:.3f}'.format(beta_stds[i]), "$}", "&", STATE_LOOKUP[state][1], "&", '{0:.3f}'.format(beta_means[i]-beta_means[16]), "&", '{0:.1f}'.format(STATE_LOOKUP[state][0]), "\\\\ ")
+
+print()
+# plot results for reproduction number
+
+# synth
+synth_iterations = []
+for i in range(10):   
+    synth_iterations.append(np.genfromtxt(SRC_DIR + f'synthetic_{i}.csv', delimiter=','))
+
+synth_matrix = np.array(synth_iterations)
+t = np.arange(0, len(synth_matrix[0]), 1)
+# analytic ground truth for the synthetic scenario: R_t(t) = -0.4 * tanh(0.05 t - 2) + 1.35
+synth_r_t = np.zeros(150, dtype=np.float64)
+for time in range(150):
+    synth_r_t[time] = -np.tanh(time * 0.05 - 2) * 0.4 + 1.35
+
+print(f"Synthetic error R_t: {get_error(synth_matrix.mean(axis=0), synth_r_t)}")
+plotter.plot(t, 
+             [synth_matrix.mean(axis=0), synth_r_t],
+             [r'$\mathcal{R}_t$', r'true $\mathcal{R}_t$'], 
+             f"synthetic_R_t_statistics", 
+             r"Synthetic data $\mathcal{R}_t$", 
+             (9, 6),
+             fill_between=[synth_matrix.std(axis=0)],
+             xlabel="time / days")
+
+pred_synth = np.genfromtxt(I_PRED_SRC_DIR + f'synthetic_0_I_prediction.csv', delimiter=',')
+
+print(f"Synthetic error I: {get_error(pred_synth[2], pred_synth[1])}")
+
+plotter.plot(pred_synth[0], 
+             [pred_synth[2], pred_synth[1]],
+             [r'prediction $I$', r'true $I$'], 
+             f"synthetic_I_prediction", 
+             r"Synthetic data $I$ prediction", 
+             (9, 6),
+             xlabel="time / days",
+             ylabel='number of people')
+
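+# day indices of pandemic events that are passed as event markers to the
+# R_t cluster plots via the event_lookup argument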
+EVENT_LOOKUP = {'start of vaccination' : 455,
+                'alpha variant' : 357,
+                'delta variant' : 473,
+                'omicron variant' : 663}
+
+
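+# the two recovery-rate settings (alpha = 1/14 and alpha = 1/5) for which result
+# files exist; only the index i is used below to select the matching files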
+ALPHA = [1 / 14, 1 / 5]
+
+
+# bookkeeping for grouping the states into consecutive cluster figures
+cluster_counter = 1
+cluster_idx = 0
+
+# results for Bremen and Thueringen, which are also shown in a separate in-text figure
+in_text_r_t_mean = []
+in_text_r_t_std = []
+in_text_I = []
+in_text_I_std = []
+
+# accumulators for the current cluster figure
+cluster_r_t_mean = []
+cluster_r_t_std = []
+cluster_I = []
+cluster_I_std = []
+cluster_states = []
+
+for k, state in enumerate(STATE_LOOKUP.keys()):
+
+    if state == "Thueringen":
+        l = 1
+    elif state == "Bremen":
+        l = 0
+    # data fetch arrays
+    r_t = []
+    pred_i = []
+    true_i = []
+    cluster_states.append(state_names[k])
+    cluster_r_t_mean.append([])
+    cluster_r_t_std.append([])
+    cluster_I.append([])
+    cluster_I_std.append([np.zeros(1200), np.zeros(1200)])
+    if state == "Thueringen" or state == "Bremen":
+        in_text_r_t_mean.append([])
+        in_text_r_t_std.append([])   
+        in_text_I.append([])
+        in_text_I_std.append([np.zeros(1200), np.zeros(1200)])
+    for i, alpha in enumerate(ALPHA):
+        iterations = []
+        predictions = []
+        true = []
+        for j in range(10): 
+            iterations.append(np.genfromtxt(SRC_DIR + f'{state}_{i}_{j}.csv', delimiter=','))
+            # I prediction files are only read for the later iterations
+            # (j >= 4 for the first three states, j >= 3 from the fourth state on)
+            if (k >= 3 and j == 3) or j > 3:
+                data = np.genfromtxt(I_PRED_SRC_DIR + f'{state}_{i}_{j}_I_prediction.csv', delimiter=',')
+                predictions.append(data[2])
+                true = data[1]
+        iterations = np.array(iterations)
+        r_t.append(iterations)
+
+        predictions = np.array(predictions)
+        pred_i.append(predictions)
+        true_i.append(true)
+
+        cluster_r_t_mean[cluster_counter-1].append(iterations.mean(axis=0))
+        cluster_r_t_std[cluster_counter-1].append(iterations.std(axis=0))
+        if state == "Thueringen" or state == "Bremen":
+            in_text_r_t_mean[l].append(iterations.mean(axis=0))
+            in_text_r_t_std[l].append(iterations.std(axis=0))
+    if state == "Thueringen" or state == "Bremen":
+        in_text_I[l].append(true_i[0])
+        in_text_I[l].append(true_i[1])
+        in_text_I[l].append(pred_i[0].mean(axis=0))
+        in_text_I[l].append(pred_i[1].mean(axis=0))
+        in_text_I_std[l].append(pred_i[0].std(axis=0))
+        in_text_I_std[l].append(pred_i[1].std(axis=0))
+
+    cluster_I[cluster_counter-1].append(true_i[0])
+    cluster_I[cluster_counter-1].append(true_i[1])
+    cluster_I[cluster_counter-1].append(pred_i[0].mean(axis=0))
+    cluster_I[cluster_counter-1].append(pred_i[1].mean(axis=0))
+    cluster_I_std[cluster_counter-1].append(pred_i[0].std(axis=0))
+    cluster_I_std[cluster_counter-1].append(pred_i[1].std(axis=0))
+
+    # plot
+    print(f"{state_names[k]} & {'{0:.3f}'.format(get_error(pred_i[0], true_i[0]))} & {'{0:.3f}'.format(get_error(pred_i[1], true_i[1]))} & \phantom{{0}} & {(r_t[0] > 1).sum(axis=1).mean()} & {(r_t[1] > 1).sum(axis=1).mean()} & {'{0:.3f}'.format(r_t[0].max(axis=1).mean())} & {'{0:.3f}'.format(r_t[1].max(axis=1).mean())}\\\ ")
+    # flush the finished cluster into two figures (clusters of four states;
+    # the last cluster also includes Germany and therefore holds five)
+    if (len(cluster_states) == 4 and state != "Thueringen") or (len(cluster_states) == 5 and state == "Germany"):
+        t = np.arange(0, 1200, 1)
+        if len(cluster_states) == 5:
+            y_lim_exception = 4
+        else:
+            y_lim_exception = None
+        plotter.cluster_plot(t,
+                            cluster_r_t_mean,
+                            [r"$\alpha=\frac{1}{14}$", r"$\alpha=\frac{1}{5}$"],
+                            (len(cluster_states), 1),
+                            (9, 6),
+                            f'r_t_cluster_{cluster_idx}',
+                            [state + r"  $\mathcal{R}_t$" for state in cluster_states],
+                            fill_between=cluster_r_t_std,
+                            event_lookup=EVENT_LOOKUP,
+                            xlabel='time / days',
+                            ylim=(0.3, 2.0),
+                            legend_loc=(0.53, 0.992),
+                            number_of_legend_columns=3)
+        plotter.cluster_plot(t,
+                             cluster_I,
+                             [r"true $I$ $\alpha=\frac{1}{14}$", 
+                             r"true $I$ $\alpha=\frac{1}{5}$", 
+                             r"prediction $I$ $\alpha=\frac{1}{14}$", 
+                             r"prediction $I$ $\alpha=\frac{1}{5}$"],
+                             (len(cluster_states), 1),
+                             (9, 6),
+                             f'I_cluster_{cluster_idx}',
+                             [state + r" $I$ prediction" for state in cluster_states],
+                             fill_between=cluster_I_std,
+                             xlabel='time / days',
+                             ylabel='number of people',
+                             same_axes=False,
+                             ylim=(0, 600000),
+                             legend_loc=(0.55, 0.992),
+                             number_of_legend_columns=2,
+                             y_lim_exception=y_lim_exception)
+        cluster_counter = 0
+        cluster_idx += 1
+        cluster_r_t_mean = []
+        cluster_r_t_std = []
+        cluster_I = []
+        cluster_I_std = []
+        cluster_states = []
+    cluster_counter += 1
+
+plotter.cluster_plot(t,
+                    in_text_r_t_mean,
+                    [r"$\alpha=\frac{1}{14}$", r"$\alpha=\frac{1}{5}$"],
+                    (2, 1),
+                    (9, 6),
+                    f'r_t_cluster_intext',
+                    [state + r"  $\mathcal{R}_t$" for state in ['Bremen', 'Thuringia']],
+                    fill_between=in_text_r_t_std,
+                    event_lookup=EVENT_LOOKUP,
+                    xlabel='time / days',
+                    ylim=(0.3, 2.0),
+                    legend_loc=(0.53, 0.999),
+                    add_y_space=0.08,
+                    number_of_legend_columns=3)
+plotter.cluster_plot(t,
+                     in_text_I,
+                     [r"true $I$ $\alpha=\frac{1}{14}$", 
+                     r"true $I$ $\alpha=\frac{1}{5}$", 
+                     r"prediction $I$ $\alpha=\frac{1}{14}$", 
+                     r"prediction $I$ $\alpha=\frac{1}{5}$"],
+                     (2, 1),
+                     (9, 6),
+                     f'I_cluster_intext',
+                     [state + r" $I$ prediction" for state in ['Bremen', 'Thuringia']],
+                     fill_between=in_text_I_std,
+                     xlabel='time / days',
+                     ylabel='number of people',
+                     ylim=(0, 600000),
+                     legend_loc=(0.55, 0.999),
+                     add_y_space=0.08,
+                     number_of_legend_columns=2)
+

plot_training_metrics.py  +21 -0

@@ -0,0 +1,21 @@
+import numpy as np
+from src.plotter import Plotter
+
+SRC_DIR = './results/'
+SIM_DIR = './visualizations/'
+MET_DIR = 'training_metrics/'
+
+plotter = Plotter()
+# plot the training losses for the Germany run (files Germany_1_0_*)
+physics_loss = np.genfromtxt(SRC_DIR + MET_DIR + f'Germany_1_0_physics_loss.csv', delimiter=',')
+obs_loss = np.genfromtxt(SRC_DIR + MET_DIR + f'Germany_1_0_obs_loss.csv', delimiter=',')
+loss = np.genfromtxt(SRC_DIR + MET_DIR + f'Germany_1_0_loss.csv', delimiter=',')
+
+t = np.arange(0, len(loss), 1)
+plotter.plot(t, 
+             [loss, physics_loss, obs_loss], 
+             ['loss', 'physics loss', 'data loss'], 
+             'Germany_1_loss',
+             'Loss', 
+             (6, 6),
+             y_log_scale=True)