# plot_datasets.py — plot the synthetic SIR dataset and the RKI-based I/SIR datasets
# (per German state and for Germany as a whole).

  1. import numpy as np
  2. import pandas as pd
  3. from src.plotter import Plotter
  4. DS_DIR = './datasets/'
  5. SRC_DIR = './results/'
  6. SIM_DIR = './visualizations/'
  7. STATE_LOOKUP = {'Schleswig_Holstein' : 'Schleswig-Holstein',
  8. 'Hamburg' : 'Hamburg',
  9. 'Niedersachsen' : 'Niedersachsen',
  10. 'Bremen' : 'Bremen',
  11. 'Nordrhein_Westfalen' : 'North Rhine-Westphalia',
  12. 'Hessen' : 'Hessen',
  13. 'Rheinland_Pfalz' : 'Rhineland-Palatinate',
  14. 'Baden_Wuerttemberg' : 'Baden-Württemberg',
  15. 'Bayern' : 'Bavaria',
  16. 'Saarland' : 'Saarland',
  17. 'Berlin' : 'Berlin',
  18. 'Brandenburg' : 'Brandenburg',
  19. 'Mecklenburg_Vorpommern' : 'Mecklenburg-Western Pomerania',
  20. 'Sachsen' : 'Saxony',
  21. 'Sachsen_Anhalt' : 'Saxony-Anhalt',
  22. 'Thueringen' : 'Thuringia'}
  23. plotter = Plotter()
  24. data = []
  25. in_text_data = [np.genfromtxt(DS_DIR + f'SIR_data.csv', delimiter=',')[1:], 1]
  26. for state in STATE_LOOKUP.keys():
  27. # print(f"plot {state}")
  28. state_data_5 = np.genfromtxt(DS_DIR + f'I_RKI_{state}_1_5.csv', delimiter=',')[1]
  29. state_data_14 = np.genfromtxt(DS_DIR + f'I_RKI_{state}_1_14.csv', delimiter=',')[1]
  30. sir_data = np.genfromtxt(DS_DIR + f'SIR_RKI_{state}_1_14.csv', delimiter=',')[1:]
  31. data.append(sir_data)
  32. t = np.arange(0, 1200, 1)
  33. if state in ["Schleswig_Holstein", "Berlin", "Thueringen"]:
  34. in_text_data.append(sir_data)
  35. plotter.plot(t,
  36. [state_data_14, state_data_5],
  37. [r'$\alpha=\frac{1}{14}$', r'$\alpha=\frac{1}{5}$'],
  38. f'{state}_datasets',
  39. f'{STATE_LOOKUP[state]}',
  40. (12,6),
  41. xlabel='time / days',
  42. ylabel='amount of people')
  43. do_log=False
  44. plotter.cluster_plot(t,
  45. data[:6],
  46. [r'$S$', r'$I$', r'$R$'],
  47. (2, 3),
  48. (6,6),
  49. "state_sir_cluster_1",
  50. list(STATE_LOOKUP.values())[:6],
  51. xlabel='time / days',
  52. ylabel='amount of people',
  53. y_log_scale=do_log,
  54. add_y_space=0.05,
  55. number_of_legend_columns=3,
  56. same_axes=False,
  57. ylim=(0, 1.85e7))
  58. plotter.cluster_plot(t,
  59. data[7:],
  60. [r'$S$', r'$I$', r'$R$'],
  61. (3, 3),
  62. (6,6),
  63. "state_sir_cluster_2",
  64. list(STATE_LOOKUP.values())[7:],
  65. xlabel='time / days',
  66. ylabel='amount of people',
  67. y_log_scale=do_log,
  68. add_y_space=0.03,
  69. number_of_legend_columns=3,
  70. same_axes=False,
  71. ylim=(0, 1.85e7))
  72. germany_data = np.genfromtxt(DS_DIR + f'SIR_RKI_Germany_1_14.csv', delimiter=',')[1:]
  73. in_text_data[1] = germany_data
  74. plotter.plot(t,
  75. germany_data,
  76. [r'$S$', r'$I$', r'$R$'],
  77. 'germany_single_sir',
  78. 'Germany',
  79. (6,6),
  80. plot_legend=False,
  81. xlabel='time / days',
  82. ylabel='amount of people',)
  83. plotter.cluster_plot(t,
  84. in_text_data,
  85. [r'$S$', r'$I$', r'$R$'],
  86. (2, 3),
  87. (6, 6),
  88. "in_text_SIR",
  89. ["synthetic SIR data", "Germany", 'Schleswig Holstein', 'Berlin', 'Thuringia'],
  90. xlabel='time / days',
  91. ylabel='amount of people',
  92. legend_loc=(0.51, 0.8),
  93. add_y_space=0,
  94. same_axes=False,
  95. free_axis=(0, 1),
  96. plot_all_labels=False)