123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302 |
- import numpy as np
- import pandas as pd
- from src.plotter import Plotter
- SRC_DIR = './results/'
- I_PRED_SRC_DIR = SRC_DIR + 'I_predictions/'
- SIM_DIR = './visualizations/'
- def get_error(y, y_ref):
- err = []
- for i in range(len(y)):
- diff = y[i] - y_ref
- err.append(np.linalg.norm(diff) / np.linalg.norm(y_ref))
- return np.array(err).mean(axis=0)
- STATE_LOOKUP = {'Schleswig_Holstein' : (79.5,0.0849),
- 'Hamburg' : (84.5, 0.0948),
- 'Niedersachsen' : (77.6, 0.0774),
- 'Bremen' : (88.3,0.0933),
- 'Nordrhein_Westfalen' : (79.5,0.0777),
- 'Hessen' : (75.8,0.1017),
- 'Rheinland_Pfalz' : (75.6,0.0895),
- 'Baden_Wuerttemberg' : (74.5,0.0796),
- 'Bayern' : (75.1,0.0952),
- 'Saarland' : (82.4,0.1080),
- 'Berlin' : (78.1,0.0667),
- 'Brandenburg' : (68.1,0.0724),
- 'Mecklenburg_Vorpommern' : (74.7,0.0540),
- 'Sachsen' : (65.1,0.1109),
- 'Sachsen_Anhalt' : (74.1,0.0785),
- 'Thueringen' : (70.3,0.0837),
- 'Germany' : (76.4, 0.0804)}
- state_names = ['Schleswig-Holstein',
- 'Hamburg',
- 'Lower Saxony',
- 'Bremen',
- 'North Rhine-Westphalia',
- 'Hesse',
- 'Rhineland-Palatinate',
- 'Baden-Württemberg',
- 'Bavaria',
- 'Saarland',
- 'Berlin',
- 'Brandenburg',
- 'Mecklenburg-Western Pomerania',
- 'Saxony',
- 'Saxony-Anhalt',
- 'Thuringia',
- 'Germany']
- plotter = Plotter(additional_colors=['yellow', 'cyan', 'magenta', ])
- # plot results for alpha and beta
- print("Visualizing Alpha and Beta results")
- # synth
- param_matrix = np.genfromtxt(SRC_DIR + f'synthetic_parameters.csv', delimiter=',')
- mean = param_matrix.mean(axis=0)
- std = param_matrix.std(axis=0)
- print("States Table form:")
- print('{0:.4f}'.format(1/3), "&", '{0:.4f}'.format(mean[0]), "&", '{0:.4f}'.format(std[0]), "&", '{0:.4f}'.format(1/2), "&", '{0:.4f}'.format(mean[1]), "&", '{0:.4f}'.format(std[1]), "\\\ ")
- plotter.scatter(np.arange(1, 6, 1), [param_matrix[:,0], param_matrix[:,1]], [r"$\alpha$", r"$\beta$"], (7,3.5), 'reproducability', '', true_values=[1/3, 1/2], xlabel='iteration')
- vaccination_ratios = []
- mean_std_parameters = {}
- for state in STATE_LOOKUP.keys():
- state_matrix = np.genfromtxt(SRC_DIR + f'{state}_parameters.csv', delimiter=',')
- mean = state_matrix.mean(axis=0)
- std = state_matrix.std(axis=0)
- mean_std_parameters.update({state : (mean, std)})
- vaccination_ratios.append(STATE_LOOKUP[state][0])
- values = np.array(list(mean_std_parameters.values()))
- means = values[:,0]
- stds = values[:,1]
- alpha_means = means[:,0]
- beta_means = means[:,1]
- alpha_stds = stds[:,0]
- beta_stds = stds[:,1]
- print(f"Vaccination corr: {np.corrcoef(beta_means, vaccination_ratios)[0, 1]}")
- vaccination_ratios = vaccination_ratios[:-1]
- sn = np.array(state_names[:-1]).copy()
- sn[12] = "MWP"
- plotter.scatter(sn,
- [alpha_means[:-1], beta_means[:-1]],
- [r'$\alpha$', r'$\beta$', ],
- (12, 6),
- 'mean_std_alpha_beta_res',
- '',
- std=[alpha_stds[:-1], beta_stds[:-1]],
- true_values=[alpha_means[-1], beta_means[-1]],
- true_label='Germany',
- xlabel_rotation=60,
- plot_legend=True,
- legend_loc="lower right")
- print("States Table form:")
- for i, state in enumerate(STATE_LOOKUP.keys()):
- print(state_names[i], "&", '{0:.3f}'.format(alpha_means[i]), "{\\tiny $\\pm",'{0:.3f}'.format(alpha_stds[i]), "$}", "&", '{0:.3f}'.format(beta_means[i]), "{\\tiny $\\pm", '{0:.3f}'.format(beta_stds[i]), "$}", "&", STATE_LOOKUP[state][1], "&", '{0:.3f}'.format(beta_means[i]-beta_means[16]), "&", '{0:.1f}'.format(STATE_LOOKUP[state][0]), "\\\ ")
- print()
- # plot results for reproduction number
- # synth
- synth_iterations = []
- for i in range(10):
- synth_iterations.append(np.genfromtxt(SRC_DIR + f'synthetic_{i}.csv', delimiter=','))
- synth_matrix = np.array(synth_iterations)
- t = np.arange(0, len(synth_matrix[0]), 1)
- synth_r_t = np.zeros(150, dtype=np.float64)
- for i, time in enumerate(range(150)):
- synth_r_t[i] = -np.tanh(time * 0.05 - 2) * 0.4 + 1.35
- print(f"Synthetic error R_t: {get_error(synth_matrix.mean(axis=0), synth_r_t)}")
- plotter.plot(t,
- [synth_matrix.mean(axis=0), synth_r_t],
- [r'$\mathcal{R}_t$', r'true $\mathcal{R}_t$'],
- f"synthetic_R_t_statistics",
- r"Synthetic data $\mathcal{R}_t$",
- (9, 6),
- fill_between=[synth_matrix.std(axis=0)],
- xlabel="time / days")
- pred_synth = np.genfromtxt(I_PRED_SRC_DIR + f'synthetic_0_I_prediction.csv', delimiter=',')
- print(f"Synthetic error I: {get_error(pred_synth[2], pred_synth[1])}")
- plotter.plot(pred_synth[0],
- [pred_synth[2], pred_synth[1]],
- [r'prediction $I$', r'true $I$'],
- f"synthetic_I_prediction",
- r"Synthetic data $I$ prediction",
- (9, 6),
- xlabel="time / days",
- ylabel='amount of people')
- EVENT_LOOKUP = {'start of vaccination' : 455,
- 'alpha variant' : 357,
- 'delta variant' : 473,
- 'omicron variant' : 663}
- ALPHA = [1 / 14, 1 / 5]
- cluster_counter = 1
- cluster_idx = 0
- in_text_r_t_mean = []
- in_text_r_t_std = []
- in_text_I = []
- in_text_I_std = []
- cluster_r_t_mean = []
- cluster_r_t_std = []
- cluster_I = []
- cluster_I_std = []
- cluster_states = []
- for k, state in enumerate(STATE_LOOKUP.keys()):
- if state == "Thueringen":
- l = 1
- elif state == "Bremen":
- l = 0
- # data fetch arrays
- r_t = []
- pred_i = []
- true_i = []
- cluster_states.append(state_names[k])
- cluster_r_t_mean.append([])
- cluster_r_t_std.append([])
- cluster_I.append([])
- cluster_I_std.append([np.zeros(1200), np.zeros(1200)])
- if state == "Thueringen" or state == "Bremen":
- in_text_r_t_mean.append([])
- in_text_r_t_std.append([])
- in_text_I.append([])
- in_text_I_std.append([np.zeros(1200), np.zeros(1200)])
- for i, alpha in enumerate(ALPHA):
- iterations = []
- predictions = []
- true = []
- for j in range(10):
- iterations.append(np.genfromtxt(SRC_DIR + f'{state}_{i}_{j}.csv', delimiter=','))
- if (k >= 3 and j == 3) or j > 3:
- data = np.genfromtxt(I_PRED_SRC_DIR + f'{state}_{i}_{j}_I_prediction.csv', delimiter=',')
- predictions.append(data[2])
- true = data[1]
- iterations = np.array(iterations)
- r_t.append(iterations)
- predictions = np.array(predictions)
- pred_i.append(predictions)
- true_i.append(true)
- cluster_r_t_mean[cluster_counter-1].append(iterations.mean(axis=0))
- cluster_r_t_std[cluster_counter-1].append(iterations.std(axis=0))
- if state == "Thueringen" or state == "Bremen":
- in_text_r_t_mean[l].append(iterations.mean(axis=0))
- in_text_r_t_std[l].append(iterations.std(axis=0))
- if state == "Thueringen" or state == "Bremen":
- in_text_I[l].append(true_i[0])
- in_text_I[l].append(true_i[1])
- in_text_I[l].append(pred_i[0].mean(axis=0))
- in_text_I[l].append(pred_i[1].mean(axis=0))
- in_text_I_std[l].append(pred_i[0].std(axis=0))
- in_text_I_std[l].append(pred_i[1].std(axis=0))
- cluster_I[cluster_counter-1].append(true_i[0])
- cluster_I[cluster_counter-1].append(true_i[1])
- cluster_I[cluster_counter-1].append(pred_i[0].mean(axis=0))
- cluster_I[cluster_counter-1].append(pred_i[1].mean(axis=0))
- cluster_I_std[cluster_counter-1].append(pred_i[0].std(axis=0))
- cluster_I_std[cluster_counter-1].append(pred_i[1].std(axis=0))
- # plot
- print(f"{state_names[k]} & {'{0:.3f}'.format(get_error(pred_i[0], true_i[0]))} & {'{0:.3f}'.format(get_error(pred_i[1], true_i[1]))} & \phantom{{0}} & {(r_t[0] > 1).sum(axis=1).mean()} & {(r_t[1] > 1).sum(axis=1).mean()} & {'{0:.3f}'.format(r_t[0].max(axis=1).mean())} & {'{0:.3f}'.format(r_t[1].max(axis=1).mean())}\\\ ")
- if len(cluster_states) == 4 and state != "Thueringen" or len(cluster_states) == 5 and state == "Germany":
- t = np.arange(0, 1200, 1)
- if len(cluster_states) == 5:
- y_lim_exception = 4
- else:
- y_lim_exception = None
- plotter.cluster_plot(t,
- cluster_r_t_mean,
- [r"$\alpha=\frac{1}{14}$", r"$\alpha=\frac{1}{5}$"],
- (len(cluster_states), 1),
- (9, 6),
- f'r_t_cluster_{cluster_idx}',
- [state + r" $\mathcal{R}_t$" for state in cluster_states],
- fill_between=cluster_r_t_std,
- event_lookup=EVENT_LOOKUP,
- xlabel='time / days',
- ylim=(0.3, 2.0),
- legend_loc=(0.53, 0.992),
- number_of_legend_columns=3)
- plotter.cluster_plot(t,
- cluster_I,
- [r"true $I$ $\alpha=\frac{1}{14}$",
- r"true $I$ $\alpha=\frac{1}{5}$",
- r"prediction $I$ $\alpha=\frac{1}{14}$",
- r"prediction $I$ $\alpha=\frac{1}{5}$"],
- (len(cluster_states), 1),
- (9, 6),
- f'I_cluster_{cluster_idx}',
- [state + r" $I$ prediction" for state in cluster_states],
- fill_between=cluster_I_std,
- xlabel='time / days',
- ylabel='amount of people',
- same_axes=False,
- ylim=(0, 600000),
- legend_loc=(0.55, 0.992),
- number_of_legend_columns=2,
- y_lim_exception=y_lim_exception)
- cluster_counter = 0
- cluster_idx += 1
- cluster_r_t_mean = []
- cluster_r_t_std = []
- cluster_I = []
- cluster_I_std = []
- cluster_states = []
- cluster_counter += 1
- plotter.cluster_plot(t,
- in_text_r_t_mean,
- [r"$\alpha=\frac{1}{14}$", r"$\alpha=\frac{1}{5}$"],
- (2, 1),
- (9, 6),
- f'r_t_cluster_intext',
- [state + r" $\mathcal{R}_t$" for state in ['Bremen', 'Thuringia']],
- fill_between=in_text_r_t_std,
- event_lookup=EVENT_LOOKUP,
- xlabel='time / days',
- ylim=(0.3, 2.0),
- legend_loc=(0.53, 0.999),
- add_y_space=0.08,
- number_of_legend_columns=3)
- plotter.cluster_plot(t,
- in_text_I,
- [r"true $I$ $\alpha=\frac{1}{14}$",
- r"true $I$ $\alpha=\frac{1}{5}$",
- r"prediction $I$ $\alpha=\frac{1}{14}$",
- r"prediction $I$ $\alpha=\frac{1}{5}$"],
- (2, 1),
- (9, 6),
- f'I_cluster_intext',
- [state + r" $I$ prediction" for state in ['Bremen', 'Thuringia']],
- fill_between=in_text_I_std,
- xlabel='time / days',
- ylabel='amount of people',
- ylim=(0, 600000),
- legend_loc=(0.55, 0.999),
- add_y_space=0.08,
- number_of_legend_columns=2)
|