123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109 |
- import numpy as np
- import pandas as pd
- from src.plotter import Plotter
# Directory paths used by the script. Input CSVs are read from DS_DIR.
# NOTE(review): SRC_DIR and SIM_DIR are not referenced by the plotting code
# visible here; confirm they are used elsewhere before removing them.
DS_DIR, SRC_DIR, SIM_DIR = './datasets/', './results/', './visualizations/'

# Maps the state identifiers that appear in dataset file names to the display
# names used as plot titles. Insertion order is significant: the state SIR
# arrays are collected in this order and the values are sliced positionally
# when the states are split across the cluster figures.
STATE_LOOKUP = dict(
    Schleswig_Holstein='Schleswig-Holstein',
    Hamburg='Hamburg',
    Niedersachsen='Niedersachsen',
    Bremen='Bremen',
    Nordrhein_Westfalen='North Rhine-Westphalia',
    Hessen='Hessen',
    Rheinland_Pfalz='Rhineland-Palatinate',
    Baden_Wuerttemberg='Baden-Württemberg',
    Bayern='Bavaria',
    Saarland='Saarland',
    Berlin='Berlin',
    Brandenburg='Brandenburg',
    Mecklenburg_Vorpommern='Mecklenburg-Western Pomerania',
    Sachsen='Saxony',
    Sachsen_Anhalt='Saxony-Anhalt',
    Thueringen='Thuringia',
)
plotter = Plotter()

# Time axis (days) shared by every plot in this script. It does not depend on
# the state, so it is built once here instead of on every loop iteration; this
# also keeps `t` defined even if STATE_LOOKUP were empty.
t = np.arange(0, 1200, 1)

# SIR arrays of every state, appended in STATE_LOOKUP order.
data = []

# Datasets for the "in_text_SIR" figure: the synthetic SIR data (all rows after
# the first), then a placeholder at index 1 that is overwritten with the
# Germany data once that has been loaded, followed by the IN_TEXT_STATES data.
in_text_data = [np.genfromtxt(DS_DIR + 'SIR_data.csv', delimiter=',')[1:], None]

# States that additionally appear in the in-text figure.
IN_TEXT_STATES = ("Schleswig_Holstein", "Berlin", "Thueringen")

for state, state_title in STATE_LOOKUP.items():
    # Row 1 of the I_RKI file for alpha = 1/5 and alpha = 1/14 respectively
    # (file suffixes _1_5 / _1_14). NOTE(review): presumably row 0 is a header
    # that genfromtxt parses as NaN; confirm against the dataset files.
    state_data_5 = np.genfromtxt(DS_DIR + f'I_RKI_{state}_1_5.csv', delimiter=',')[1]
    state_data_14 = np.genfromtxt(DS_DIR + f'I_RKI_{state}_1_14.csv', delimiter=',')[1]
    # Every row after the first of the state's SIR file.
    sir_data = np.genfromtxt(DS_DIR + f'SIR_RKI_{state}_1_14.csv', delimiter=',')[1:]
    data.append(sir_data)
    if state in IN_TEXT_STATES:
        in_text_data.append(sir_data)
    plotter.plot(t,
                 [state_data_14, state_data_5],
                 [r'$\alpha=\frac{1}{14}$', r'$\alpha=\frac{1}{5}$'],
                 f'{state}_datasets',
                 state_title,
                 (12, 6),
                 xlabel='time / days',
                 ylabel='amount of people')
do_log = False

# All state display names, in the same order as the SIR arrays in `data`.
state_names = list(STATE_LOOKUP.values())

# Number of states shown in the first cluster figure; all remaining states go
# into the second one. Both figures slice at this same index, so every state is
# plotted exactly once. Slicing the second figure from index 7 would silently
# drop index 6 (Rheinland_Pfalz).
first_cluster_size = 6

# Keyword arguments shared by both cluster figures.
cluster_kwargs = dict(
    xlabel='time / days',
    ylabel='amount of people',
    y_log_scale=do_log,
    number_of_legend_columns=3,
    same_axes=False,
    ylim=(0, 1.85e7),
)

plotter.cluster_plot(t,
                     data[:first_cluster_size],
                     [r'$S$', r'$I$', r'$R$'],
                     (2, 3),
                     (6, 6),
                     "state_sir_cluster_1",
                     state_names[:first_cluster_size],
                     add_y_space=0.05,
                     **cluster_kwargs)

# The second figure holds the remaining ten states. A 5x2 grid fits them
# exactly, assuming the grid tuple is (rows, cols), as free_axis=(0, 1) in the
# in-text figure suggests; confirm against Plotter.
# NOTE(review): add_y_space=0.03 was tuned for the previous 3x3 layout, so
# re-check the legend spacing visually.
plotter.cluster_plot(t,
                     data[first_cluster_size:],
                     [r'$S$', r'$I$', r'$R$'],
                     (5, 2),
                     (6, 6),
                     "state_sir_cluster_2",
                     state_names[first_cluster_size:],
                     add_y_space=0.03,
                     **cluster_kwargs)
# Nationwide SIR data: every row after the first of the CSV. It is plotted on
# its own and reused as the second subplot of the in-text figure.
germany_data = np.genfromtxt(DS_DIR + 'SIR_RKI_Germany_1_14.csv', delimiter=',')[1:]

# Index 1 of in_text_data holds a placeholder set when the list was created.
# Replace it with the Germany data so the datasets line up with the titles below.
in_text_data[1] = germany_data

plotter.plot(t,
             germany_data,
             [r'$S$', r'$I$', r'$R$'],
             'germany_single_sir',
             'Germany',
             (6, 6),
             plot_legend=False,
             xlabel='time / days',
             ylabel='amount of people')

# Subplot titles in dataset order: synthetic data, Germany, then the three
# in-text states in the order they were appended (STATE_LOOKUP order). Taking
# the names from STATE_LOOKUP keeps them identical to the per-state plot titles.
in_text_titles = (["synthetic SIR data", "Germany"]
                  + [STATE_LOOKUP[s] for s in ("Schleswig_Holstein", "Berlin", "Thueringen")])

# Five datasets in a 2x3 grid. NOTE(review): free_axis presumably marks the one
# unused subplot, and legend_loc appears to place the legend there; confirm
# against Plotter.cluster_plot.
plotter.cluster_plot(t,
                     in_text_data,
                     [r'$S$', r'$I$', r'$R$'],
                     (2, 3),
                     (6, 6),
                     "in_text_SIR",
                     in_text_titles,
                     xlabel='time / days',
                     ylabel='amount of people',
                     legend_loc=(0.51, 0.8),
                     add_y_space=0,
                     same_axes=False,
                     free_axis=(0, 1),
                     plot_all_labels=False)
|