# transform_SIR.py
from pathlib import Path

import numpy as np
import pandas as pd

from src.plotter import Plotter
  4. dataset_path = 'datasets/COVID-19-Todesfaelle_in_Deutschland/'
  5. FONT_COLOR = '#595959'
  6. SUSCEPTIBLE = '#6399f7'
  7. INFECTIOUS = '#f56262'
  8. REMOVED = '#83eb5e'
  9. def transform_general_data(plotter:Plotter, plot_name='', plot_title='', sample_rate=1, exclude=[], plot_size=(12,6), yscale_log=False, plot_legend=True):
  10. """Function to generate the SIR split from the data in the COVID-19-Todesfaelle_in_Deutschland dataset.
  11. """
  12. # read the data
  13. df = pd.read_csv(dataset_path + 'COVID-19-Todesfaelle_Deutschland.csv')
  14. df = df.drop(df.index[1200:])
  15. # population of germany at the end of 2019
  16. N = 83100000
  17. S, I, R = np.zeros(df.shape[0]), np.zeros(df.shape[0]), np.zeros(df.shape[0])
  18. # S_0 = N - I_0
  19. S[0] = N - df['Faelle_gesamt'][0]
  20. # I_0 = overall cases at the day - overall death cases at the day
  21. I[0] = df['Faelle_gesamt'][0] - df['Todesfaelle_gesamt'][0]
  22. # R_0 = overall death cases at the day
  23. R[0] = df['Todesfaelle_gesamt'][0]
  24. # the recovery time is 14 days
  25. recovery_queue = np.zeros(14)
  26. for day in range(1, df.shape[0]):
  27. infections = df['Faelle_gesamt'][day] - df['Faelle_gesamt'][day-1]
  28. deaths = df['Todesfaelle_neu'][day]
  29. recoveries = recovery_queue[0]
  30. S[day] = S[day-1] - infections
  31. I[day] = I[day-1] + infections - deaths - recoveries
  32. R[day] = R[day-1] + deaths + recoveries
  33. # update recovery queue
  34. if I[day] < 0:
  35. recovery_queue[-1] -= I[day]
  36. I[day] = 0
  37. recovery_queue[:-1] = recovery_queue[1:]
  38. recovery_queue[-1] = infections
  39. t = np.arange(0, df.shape[0], 1)
  40. if plotter != None:
  41. # plot graphs
  42. plots = []
  43. labels = []
  44. if 'S' not in exclude:
  45. plots.append(S)
  46. labels.append('S')
  47. if 'I' not in exclude:
  48. plots.append(I)
  49. labels.append('I')
  50. if 'R' not in exclude:
  51. plots.append(R)
  52. labels.append('R')
  53. plotter.plot(t, plots, labels, plot_name, plot_title, plot_size, y_log_scale=yscale_log, plot_legend=plot_legend, xlabel='time / days', ylabel='amount of poeple')
  54. COVID_Data = np.asarray([t[0::sample_rate],
  55. S[0::sample_rate],
  56. I[0::sample_rate],
  57. R[0::sample_rate]])
  58. np.savetxt(f"datasets/SIR_RKI_{sample_rate}.csv", COVID_Data, delimiter=",")