# plot_results.py — visualize fitted SIR parameters, R_t curves and I predictions.
import numpy as np
import pandas as pd
from src.plotter import Plotter

# Directory containing the fitted-parameter / R_t iteration CSV files.
SRC_DIR = './results/'
# Sub-directory with the I-compartment prediction CSVs.
I_PRED_SRC_DIR = SRC_DIR + 'I_predictions/'
# Output directory for figures — not referenced in this script's visible code;
# presumably consumed by Plotter. TODO confirm.
SIM_DIR = './visualizations/'
  7. def get_error(y, y_ref):
  8. err = []
  9. for i in range(len(y)):
  10. diff = y[i] - y_ref
  11. err.append(np.linalg.norm(diff) / np.linalg.norm(y_ref))
  12. return np.array(err).mean(axis=0)
# Per-state reference data, keyed by German federal state (plus 'Germany').
# Tuple[0]: vaccination ratio in percent (correlated with beta below).
# Tuple[1]: external per-state rate printed in the LaTeX table next to beta —
# presumably a reported transmission-rate estimate; TODO confirm source.
STATE_LOOKUP = {'Schleswig_Holstein' : (79.5,0.0849),
                'Hamburg' : (84.5, 0.0948),
                'Niedersachsen' : (77.6, 0.0774),
                'Bremen' : (88.3,0.0933),
                'Nordrhein_Westfalen' : (79.5,0.0777),
                'Hessen' : (75.8,0.1017),
                'Rheinland_Pfalz' : (75.6,0.0895),
                'Baden_Wuerttemberg' : (74.5,0.0796),
                'Bayern' : (75.1,0.0952),
                'Saarland' : (82.4,0.1080),
                'Berlin' : (78.1,0.0667),
                'Brandenburg' : (68.1,0.0724),
                'Mecklenburg_Vorpommern' : (74.7,0.0540),
                'Sachsen' : (65.1,0.1109),
                'Sachsen_Anhalt' : (74.1,0.0785),
                'Thueringen' : (70.3,0.0837),
                'Germany' : (76.4, 0.0804)}
# English display names, in the same order as the STATE_LOOKUP keys above
# (index i here corresponds to the i-th key; 'Germany' is last).
state_names = ['Schleswig-Holstein',
               'Hamburg',
               'Lower Saxony',
               'Bremen',
               'North Rhine-Westphalia',
               'Hesse',
               'Rhineland-Palatinate',
               'Baden-Württemberg',
               'Bavaria',
               'Saarland',
               'Berlin',
               'Brandenburg',
               'Mecklenburg-Western Pomerania',
               'Saxony',
               'Saxony-Anhalt',
               'Thuringia',
               'Germany']
# Shared plotter instance; extra colors are needed for multi-series figures.
plotter = Plotter(additional_colors=['yellow', 'cyan', 'magenta', ])
# plot results for alpha and beta
print("Visualizing Alpha and Beta results")
# synth
# Rows = fit iterations, columns = (alpha, beta) estimates on synthetic data.
param_matrix = np.genfromtxt(SRC_DIR + f'synthetic_parameters.csv', delimiter=',')
mean = param_matrix.mean(axis=0)
std = param_matrix.std(axis=0)
print("States Table form:")
# LaTeX table row: true alpha (1/3), fitted mean/std, true beta (1/2), fitted mean/std.
print('{0:.4f}'.format(1/3), "&", '{0:.4f}'.format(mean[0]), "&", '{0:.4f}'.format(std[0]), "&", '{0:.4f}'.format(1/2), "&", '{0:.4f}'.format(mean[1]), "&", '{0:.4f}'.format(std[1]), "\\\ ")
# Scatter the per-iteration estimates against the ground-truth values.
plotter.scatter(np.arange(1, 6, 1), [param_matrix[:,0], param_matrix[:,1]], [r"$\alpha$", r"$\beta$"], (7,3.5), 'reproducability', '', true_values=[1/3, 1/2], xlabel='iteration')
vaccination_ratios = []
mean_std_parameters = {}
# Per-state fit statistics: mean/std of (alpha, beta) over all iterations.
for state in STATE_LOOKUP.keys():
    state_matrix = np.genfromtxt(SRC_DIR + f'{state}_parameters.csv', delimiter=',')
    mean = state_matrix.mean(axis=0)
    std = state_matrix.std(axis=0)
    mean_std_parameters.update({state : (mean, std)})
    vaccination_ratios.append(STATE_LOOKUP[state][0])
# values has shape (n_states, 2, 2): axis 1 = (mean, std), axis 2 = (alpha, beta).
values = np.array(list(mean_std_parameters.values()))
means = values[:,0]
stds = values[:,1]
alpha_means = means[:,0]
beta_means = means[:,1]
alpha_stds = stds[:,0]
beta_stds = stds[:,1]
print(f"Vaccination corr: {np.corrcoef(beta_means, vaccination_ratios)[0, 1]}")
# Drop the aggregate 'Germany' entry (last) for the per-state scatter plot;
# it is shown separately as the horizontal reference line (true_values).
vaccination_ratios = vaccination_ratios[:-1]
sn = np.array(state_names[:-1]).copy()
sn[12] = "MWP"  # abbreviate Mecklenburg-Western Pomerania so the x-tick fits
plotter.scatter(sn,
                [alpha_means[:-1], beta_means[:-1]],
                [r'$\alpha$', r'$\beta$', ],
                (12, 6),
                'mean_std_alpha_beta_res',
                '',
                std=[alpha_stds[:-1], beta_stds[:-1]],
                true_values=[alpha_means[-1], beta_means[-1]],
                true_label='Germany',
                xlabel_rotation=60,
                plot_legend=True,
                legend_loc="lower right")
print("States Table form:")
# LaTeX table: per-state alpha/beta (±std), external rate, beta offset from
# Germany (index 16), and vaccination ratio.
for i, state in enumerate(STATE_LOOKUP.keys()):
    print(state_names[i], "&", '{0:.3f}'.format(alpha_means[i]), "{\\tiny $\\pm",'{0:.3f}'.format(alpha_stds[i]), "$}", "&", '{0:.3f}'.format(beta_means[i]), "{\\tiny $\\pm", '{0:.3f}'.format(beta_stds[i]), "$}", "&", STATE_LOOKUP[state][1], "&", '{0:.3f}'.format(beta_means[i]-beta_means[16]), "&", '{0:.1f}'.format(STATE_LOOKUP[state][0]), "\\\ ")
print()
  92. # plot results for reproduction number
  93. # synth
  94. synth_iterations = []
  95. for i in range(10):
  96. synth_iterations.append(np.genfromtxt(SRC_DIR + f'synthetic_{i}.csv', delimiter=','))
  97. synth_matrix = np.array(synth_iterations)
  98. t = np.arange(0, len(synth_matrix[0]), 1)
  99. synth_r_t = np.zeros(150, dtype=np.float64)
  100. for i, time in enumerate(range(150)):
  101. synth_r_t[i] = -np.tanh(time * 0.05 - 2) * 0.4 + 1.35
  102. print(f"Synthetic error R_t: {get_error(synth_matrix.mean(axis=0), synth_r_t)}")
  103. plotter.plot(t,
  104. [synth_matrix.mean(axis=0), synth_r_t],
  105. [r'$\mathcal{R}_t$', r'true $\mathcal{R}_t$'],
  106. f"synthetic_R_t_statistics",
  107. r"Synthetic data $\mathcal{R}_t$",
  108. (9, 6),
  109. fill_between=[synth_matrix.std(axis=0)],
  110. xlabel="time / days")
  111. pred_synth = np.genfromtxt(I_PRED_SRC_DIR + f'synthetic_0_I_prediction.csv', delimiter=',')
  112. print(f"Synthetic error I: {get_error(pred_synth[2], pred_synth[1])}")
  113. plotter.plot(pred_synth[0],
  114. [pred_synth[2], pred_synth[1]],
  115. [r'prediction $I$', r'true $I$'],
  116. f"synthetic_I_prediction",
  117. r"Synthetic data $I$ prediction",
  118. (9, 6),
  119. xlabel="time / days",
  120. ylabel='amount of people')
# Day offsets of notable pandemic events, drawn as vertical markers in the
# R_t cluster plots (passed as event_lookup below).
EVENT_LOOKUP = {'start of vaccination' : 455,
                'alpha variant' : 357,
                'delta variant' : 473,
                'omicron variant' : 663}
# The two recovery-rate settings the runs were produced with (1/14 and 1/5).
ALPHA = [1 / 14, 1 / 5]
cluster_counter = 1  # 1-based slot of the current state within its cluster
cluster_idx = 0      # running index used in the cluster figure file names
# Accumulators for the two-state "in-text" figures (Bremen / Thuringia).
in_text_r_t_mean = []
in_text_r_t_std = []
in_text_I = []
in_text_I_std = []
# Accumulators for the rolling 4-state (last: 5-state) cluster figures.
cluster_r_t_mean = []
cluster_r_t_std = []
cluster_I = []
cluster_I_std = []
cluster_states = []
# Main evaluation loop: for every state load the R_t iterations and the
# I predictions for both ALPHA settings, accumulate them into the cluster
# buffers, and flush one cluster figure per 4 states (5 for the last cluster,
# which ends with 'Germany').
for k, state in enumerate(STATE_LOOKUP.keys()):
    # l selects the subplot row in the in-text figures (0 = Bremen, 1 = Thueringen);
    # Bremen precedes Thueringen in STATE_LOOKUP, so l is set before first use.
    if state == "Thueringen":
        l = 1
    elif state == "Bremen":
        l = 0
    # data fetch arrays
    r_t = []      # per-alpha stacked R_t iterations
    pred_i = []   # per-alpha stacked I predictions
    true_i = []   # per-alpha true I curve
    cluster_states.append(state_names[k])
    cluster_r_t_mean.append([])
    cluster_r_t_std.append([])
    cluster_I.append([])
    # Zero std placeholders for the two "true I" series (no band drawn for them);
    # 1200 apparently matches the time-series length — TODO confirm.
    cluster_I_std.append([np.zeros(1200), np.zeros(1200)])
    if state == "Thueringen" or state == "Bremen":
        in_text_r_t_mean.append([])
        in_text_r_t_std.append([])
        in_text_I.append([])
        in_text_I_std.append([np.zeros(1200), np.zeros(1200)])
    for i, alpha in enumerate(ALPHA):
        iterations = []
        predictions = []
        true = []
        for j in range(10):
            iterations.append(np.genfromtxt(SRC_DIR + f'{state}_{i}_{j}.csv', delimiter=','))
            # Only a subset of runs has I-prediction files; the j/k rule below
            # mirrors that — TODO confirm against the result generator.
            if (k >= 3 and j == 3) or j > 3:
                data = np.genfromtxt(I_PRED_SRC_DIR + f'{state}_{i}_{j}_I_prediction.csv', delimiter=',')
                predictions.append(data[2])  # row 2: predicted I
                true = data[1]               # row 1: reference I (same for every j)
        iterations = np.array(iterations)
        r_t.append(iterations)
        predictions = np.array(predictions)
        pred_i.append(predictions)
        true_i.append(true)
        # cluster_counter-1 is this state's slot in the cluster buffers.
        cluster_r_t_mean[cluster_counter-1].append(iterations.mean(axis=0))
        cluster_r_t_std[cluster_counter-1].append(iterations.std(axis=0))
        if state == "Thueringen" or state == "Bremen":
            in_text_r_t_mean[l].append(iterations.mean(axis=0))
            in_text_r_t_std[l].append(iterations.std(axis=0))
    if state == "Thueringen" or state == "Bremen":
        # Series order matters: both true curves first, then both prediction
        # means — it must match the label order of the cluster plots.
        in_text_I[l].append(true_i[0])
        in_text_I[l].append(true_i[1])
        in_text_I[l].append(pred_i[0].mean(axis=0))
        in_text_I[l].append(pred_i[1].mean(axis=0))
        in_text_I_std[l].append(pred_i[0].std(axis=0))
        in_text_I_std[l].append(pred_i[1].std(axis=0))
    cluster_I[cluster_counter-1].append(true_i[0])
    cluster_I[cluster_counter-1].append(true_i[1])
    cluster_I[cluster_counter-1].append(pred_i[0].mean(axis=0))
    cluster_I[cluster_counter-1].append(pred_i[1].mean(axis=0))
    cluster_I_std[cluster_counter-1].append(pred_i[0].std(axis=0))
    cluster_I_std[cluster_counter-1].append(pred_i[1].std(axis=0))
    # plot
    # LaTeX table row: relative I error, days with R_t > 1 and max R_t,
    # each for both alpha settings.
    print(f"{state_names[k]} & {'{0:.3f}'.format(get_error(pred_i[0], true_i[0]))} & {'{0:.3f}'.format(get_error(pred_i[1], true_i[1]))} & \phantom{{0}} & {(r_t[0] > 1).sum(axis=1).mean()} & {(r_t[1] > 1).sum(axis=1).mean()} & {'{0:.3f}'.format(r_t[0].max(axis=1).mean())} & {'{0:.3f}'.format(r_t[1].max(axis=1).mean())}\\\ ")
    # Flush a cluster figure after 4 states. Thueringen (4th of the final
    # group) is deferred so the last cluster holds 5 rows ending with Germany.
    # NOTE(review): relies on 'and' binding tighter than 'or'.
    if len(cluster_states) == 4 and state != "Thueringen" or len(cluster_states) == 5 and state == "Germany":
        t = np.arange(0, 1200, 1)
        if len(cluster_states) == 5:
            y_lim_exception = 4  # let the Germany subplot use its own y-limits
        else:
            y_lim_exception = None
        plotter.cluster_plot(t,
                             cluster_r_t_mean,
                             [r"$\alpha=\frac{1}{14}$", r"$\alpha=\frac{1}{5}$"],
                             (len(cluster_states), 1),
                             (9, 6),
                             f'r_t_cluster_{cluster_idx}',
                             [state + r" $\mathcal{R}_t$" for state in cluster_states],
                             fill_between=cluster_r_t_std,
                             event_lookup=EVENT_LOOKUP,
                             xlabel='time / days',
                             ylim=(0.3, 2.0),
                             legend_loc=(0.53, 0.992),
                             number_of_legend_columns=3)
        plotter.cluster_plot(t,
                             cluster_I,
                             [r"true $I$ $\alpha=\frac{1}{14}$",
                              r"true $I$ $\alpha=\frac{1}{5}$",
                              r"prediction $I$ $\alpha=\frac{1}{14}$",
                              r"prediction $I$ $\alpha=\frac{1}{5}$"],
                             (len(cluster_states), 1),
                             (9, 6),
                             f'I_cluster_{cluster_idx}',
                             [state + r" $I$ prediction" for state in cluster_states],
                             fill_between=cluster_I_std,
                             xlabel='time / days',
                             ylabel='amount of people',
                             same_axes=False,
                             ylim=(0, 600000),
                             legend_loc=(0.55, 0.992),
                             number_of_legend_columns=2,
                             y_lim_exception=y_lim_exception)
        # Reset the cluster buffers for the next group of states.
        cluster_counter = 0
        cluster_idx += 1
        cluster_r_t_mean = []
        cluster_r_t_std = []
        cluster_I = []
        cluster_I_std = []
        cluster_states = []
    cluster_counter += 1
# "In-text" figures: two-row cluster plots for Bremen and Thuringia only.
# NOTE(review): t here is the 0..1199 axis last assigned inside the loop above.
plotter.cluster_plot(t,
                     in_text_r_t_mean,
                     [r"$\alpha=\frac{1}{14}$", r"$\alpha=\frac{1}{5}$"],
                     (2, 1),
                     (9, 6),
                     f'r_t_cluster_intext',
                     [state + r" $\mathcal{R}_t$" for state in ['Bremen', 'Thuringia']],
                     fill_between=in_text_r_t_std,
                     event_lookup=EVENT_LOOKUP,
                     xlabel='time / days',
                     ylim=(0.3, 2.0),
                     legend_loc=(0.53, 0.999),
                     add_y_space=0.08,
                     number_of_legend_columns=3)
plotter.cluster_plot(t,
                     in_text_I,
                     [r"true $I$ $\alpha=\frac{1}{14}$",
                      r"true $I$ $\alpha=\frac{1}{5}$",
                      r"prediction $I$ $\alpha=\frac{1}{14}$",
                      r"prediction $I$ $\alpha=\frac{1}{5}$"],
                     (2, 1),
                     (9, 6),
                     f'I_cluster_intext',
                     [state + r" $I$ prediction" for state in ['Bremen', 'Thuringia']],
                     fill_between=in_text_I_std,
                     xlabel='time / days',
                     ylabel='amount of people',
                     ylim=(0, 600000),
                     legend_loc=(0.55, 0.999),
                     add_y_space=0.08,
                     number_of_legend_columns=2)