# transform_SIR.py
  1. import numpy as np
  2. import pandas as pd
  3. import matplotlib.pylab as plt
  4. from matplotlib import rcParams
  5. dataset_path = 'datasets/COVID-19-Todesfaelle_in_Deutschland/'
  6. FONT_COLOR = '#595959'
  7. SUSCEPTIBLE = '#6399f7'
  8. INFECTIOUS = '#f56262'
  9. REMOVED = '#83eb5e'
  10. def transform_general_data(plot_name: str, plot_title: str, exclude=[], plot_size=(12,6), yscale_log=False, plot_legend=True):
  11. """Function to generate the SIR split from the data in the COVID-19-Todesfaelle_in_Deutschland dataset.
  12. """
  13. # read the data
  14. df = pd.read_csv(dataset_path + 'COVID-19-Todesfaelle_Deutschland.csv')
  15. df = df.drop(df.index[1200:])
  16. # population of germany at the end of 2019
  17. N = 83100000
  18. S, I, R = np.zeros(df.shape[0]), np.zeros(df.shape[0]), np.zeros(df.shape[0])
  19. # S_0 = N - I_0
  20. S[0] = N - df['Faelle_gesamt'][0]
  21. # I_0 = overall cases at the day - overall death cases at the day
  22. I[0] = df['Faelle_gesamt'][0] - df['Todesfaelle_gesamt'][0]
  23. # R_0 = overall death cases at the day
  24. R[0] = df['Todesfaelle_gesamt'][0]
  25. # the recovery time is 14 days
  26. recovery_queue = np.zeros(14)
  27. for day in range(1, df.shape[0]):
  28. infections = df['Faelle_gesamt'][day] - df['Faelle_gesamt'][day-1]
  29. deaths = df['Todesfaelle_neu'][day]
  30. recoveries = recovery_queue[0]
  31. S[day] = S[day-1] - infections
  32. I[day] = I[day-1] + infections - deaths - recoveries
  33. R[day] = R[day-1] + deaths + recoveries
  34. # update recovery queue
  35. if I[day] < 0:
  36. recovery_queue[-1] -= I[day]
  37. I[day] = 0
  38. recovery_queue[:-1] = recovery_queue[1:]
  39. recovery_queue[-1] = infections
  40. # plot graphs
  41. t = np.arange(0, df.shape[0], 1)
  42. rcParams['font.family'] = 'Comfortaa'
  43. rcParams['font.size'] = 12
  44. rcParams['text.color'] = FONT_COLOR
  45. rcParams['axes.labelcolor'] = FONT_COLOR
  46. rcParams['xtick.color'] = FONT_COLOR
  47. rcParams['ytick.color'] = FONT_COLOR
  48. slide3 = plt.figure(figsize=plot_size)
  49. ax = slide3.add_subplot(111, facecolor='#dddddd', axisbelow=True)
  50. ax.set_facecolor('xkcd:white')
  51. if 'S' not in exclude:
  52. ax.plot(t, S, label='Susceptible', c=SUSCEPTIBLE, lw=3)
  53. if 'I' not in exclude:
  54. ax.plot(t, I, label='Infectious', c=INFECTIOUS, lw=3)
  55. if 'R' not in exclude:
  56. ax.plot(t, R, label='Removed', c=REMOVED, lw=3)
  57. if yscale_log:
  58. plt.yscale('log')
  59. plt.ylabel('amount of poeple')
  60. plt.xlabel('time / days')
  61. plt.title(plot_title)
  62. ax.yaxis.set_tick_params(length=0)
  63. if plot_legend:
  64. plt.legend()
  65. ax.yaxis.set_tick_params(length=0, which='both')
  66. ax.xaxis.set_tick_params(length=0, which='both')
  67. ax.grid(which='major', c='black', lw=0.2, ls='-')
  68. for spine in ('top', 'right', 'bottom', 'left'):
  69. ax.spines[spine].set_visible(False)
  70. slide3.savefig(f'visualizations/{plot_name}.png', transparent=True)