from collections import deque

import numpy as np
import pandas as pd
import matplotlib.pylab as plt
from matplotlib import rcParams

dataset_path = 'datasets/COVID-19-Todesfaelle_in_Deutschland/'

FONT_COLOR = '#595959'
SUSCEPTIBLE = '#6399f7'
INFECTIOUS = '#f56262'
REMOVED = '#83eb5e'

# population of Germany at the end of 2019
GERMAN_POPULATION = 83_100_000


def _compute_sir(total_cases, new_deaths, total_deaths, population, recovery_days):
    """Split daily case/death counts into Susceptible/Infectious/Removed series.

    Day 0 is initialised from the cumulative totals. On every later day the new
    infections (difference of cumulative cases) move from S to I, that day's
    deaths move from I to R, and the people infected ``recovery_days`` days
    earlier move from I to R.

    Args:
        total_cases: cumulative confirmed cases, one entry per day.
        new_deaths: deaths reported on each day (entry 0 is not used).
        total_deaths: cumulative deaths per day (only entry 0 is used).
        population: total population N; S + I + R equals N on every day.
        recovery_days: number of days until an infected person recovers.

    Returns:
        Tuple ``(S, I, R)`` of float arrays with one entry per day.

    Raises:
        ValueError: if ``recovery_days`` is smaller than 1.
    """
    if recovery_days < 1:
        raise ValueError('recovery_days must be at least 1')

    n_days = len(total_cases)
    S, I, R = np.zeros(n_days), np.zeros(n_days), np.zeros(n_days)
    # S_0 = N - I_0
    S[0] = population - total_cases[0]
    # I_0 = overall cases at the day - overall death cases at the day
    I[0] = total_cases[0] - total_deaths[0]
    # R_0 = overall death cases at the day
    R[0] = total_deaths[0]

    # pending[0] is the number of people recovering today; each day the oldest
    # entry is dropped and that day's new infections are appended.
    # NOTE(review): the I_0 people present on day 0 are never scheduled here,
    # so they only leave I through deaths — kept as in the original model.
    pending = deque([0.0] * recovery_days, maxlen=recovery_days)

    for day in range(1, n_days):
        infections = total_cases[day] - total_cases[day - 1]
        deaths = new_deaths[day]
        recoveries = pending[0]

        S[day] = S[day - 1] - infections
        I[day] = I[day - 1] + infections - deaths - recoveries
        R[day] = R[day - 1] + deaths + recoveries

        if I[day] < 0:
            # More people left I than it contained (e.g. the cumulative case
            # count decreased, or deaths + recoveries exceeded I). Clamp I at
            # zero and postpone the excess recoveries to a later day. Take the
            # excess back out of R as well; otherwise it would be counted in R
            # now AND again when the queue releases it, breaking S + I + R == N.
            deficit = -I[day]
            pending[-1] += deficit
            R[day] -= deficit
            I[day] = 0

        # maxlen drops pending[0] (today's recoveries) when appending.
        pending.append(infections)

    return S, I, R


def _apply_plot_style():
    """Set the global matplotlib rcParams used by the plots of this module."""
    rcParams.update({
        'font.family': 'Comfortaa',
        'font.size': 12,
        'text.color': FONT_COLOR,
        'axes.labelcolor': FONT_COLOR,
        'xtick.color': FONT_COLOR,
        'ytick.color': FONT_COLOR,
    })


def transform_general_data(plot_name: str, plot_title: str, exclude=(),
                           plot_size=(12, 6), yscale_log: bool = False,
                           plot_legend: bool = True,
                           population: int = GERMAN_POPULATION,
                           recovery_days: int = 14,
                           max_days=1200):
    """Generate and plot the SIR split of the COVID-19-Todesfaelle_in_Deutschland data.

    Reads ``COVID-19-Todesfaelle_Deutschland.csv`` from ``dataset_path`` (a path
    relative to the current working directory), derives daily S/I/R series and
    saves the plot as ``visualizations/<plot_name>.png``.

    Args:
        plot_name: file name (without extension) of the saved PNG.
        plot_title: title shown above the plot.
        exclude: collection of series keys ('S', 'I', 'R') not to draw.
        plot_size: figure size in inches, passed to ``plt.figure``.
        yscale_log: use a logarithmic y axis if True.
        plot_legend: draw a legend if True.
        population: total population N used for the susceptible series.
        recovery_days: days after infection at which a case counts as removed.
        max_days: only the first ``max_days`` rows of the dataset are used
            (``None`` uses all rows).

    Raises:
        ValueError: if the dataset has no rows, or ``recovery_days`` < 1.

    Side effects:
        Updates the global matplotlib ``rcParams``. The figure is left open
        (not ``plt.close``-d), so callers producing many plots may want to
        close it themselves.
    """
    # read the data
    df = pd.read_csv(dataset_path + 'COVID-19-Todesfaelle_Deutschland.csv')
    df = df.iloc[:max_days]
    if df.empty:
        raise ValueError('dataset contains no rows')

    # Convert the columns once; positional NumPy indexing inside the daily
    # loop is much cheaper than per-element Series lookups.
    S, I, R = _compute_sir(
        df['Faelle_gesamt'].to_numpy(),
        df['Todesfaelle_neu'].to_numpy(),
        df['Todesfaelle_gesamt'].to_numpy(),
        population,
        recovery_days,
    )

    # plot graphs
    t = np.arange(len(S))
    _apply_plot_style()

    fig = plt.figure(figsize=plot_size)
    ax = fig.add_subplot(111, facecolor='xkcd:white', axisbelow=True)

    series = (
        ('S', S, 'Susceptible', SUSCEPTIBLE),
        ('I', I, 'Infectious', INFECTIOUS),
        ('R', R, 'Removed', REMOVED),
    )
    for key, values, label, color in series:
        if key not in exclude:
            ax.plot(t, values, label=label, c=color, lw=3)

    if yscale_log:
        ax.set_yscale('log')
    ax.set_ylabel('amount of people')
    ax.set_xlabel('time / days')
    ax.set_title(plot_title)
    if plot_legend:
        ax.legend()

    ax.yaxis.set_tick_params(length=0, which='both')
    ax.xaxis.set_tick_params(length=0, which='both')
    ax.grid(which='major', c='black', lw=0.2, ls='-')
    for spine in ('top', 'right', 'bottom', 'left'):
        ax.spines[spine].set_visible(False)

    fig.savefig(f'visualizations/{plot_name}.png', transparent=True)