transform_data.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170
  1. import numpy as np
  2. import pandas as pd
  3. from datetime import timedelta
  4. from src.plotter import Plotter
  5. def transform_general_to_SIR(plotter:Plotter, dataset_path='datasets/COVID-19-Todesfaelle_in_Deutschland/', plot_name='', plot_title='', sample_rate=1, exclude=[], plot_size=(12,6), yscale_log=False, plot_legend=True):
  6. """Function to generate the SIR split from the data in the COVID-19-Todesfaelle_in_Deutschland dataset.
  7. Args:
  8. plotter (Plotter): Plotter object to plot dataset curves.
  9. dataset_path (str, optional): Path to the dataset directory. Defaults to 'datasets/COVID-19-Todesfaelle_in_Deutschland/'.
  10. plot_name (str, optional): Name of the plot file. Defaults to ''.
  11. plot_title (str, optional): Title of the plot. Defaults to ''.
  12. sample_rate (int, optional): Sample rate used to sample the timepoints. Defaults to 1.
  13. exclude (list, optional): List of groups that are to excluded from the plot. Defaults to [].
  14. plot_size (tuple, optional): Size of the plot in (x, y) format. Defaults to (12,6).
  15. yscale_log (bool, optional): Controls if the y axis of the plot will be scaled in log scale. Defaults to False.
  16. plot_legend (bool, optional): Controls if the legend is to be plotted. Defaults to True.
  17. """
  18. # read the data
  19. df = pd.read_csv(dataset_path + 'COVID-19-Todesfaelle_Deutschland.csv')
  20. df = df.drop(df.index[1200:])
  21. # population of germany at the end of 2019
  22. N = 83100000
  23. S, I, R = np.zeros(df.shape[0]), np.zeros(df.shape[0]), np.zeros(df.shape[0])
  24. # S_0 = N - I_0
  25. S[0] = N - df['Faelle_gesamt'][0]
  26. # I_0 = overall cases at the day - overall death cases at the day
  27. I[0] = df['Faelle_gesamt'][0] - df['Todesfaelle_gesamt'][0]
  28. # R_0 = overall death cases at the day
  29. R[0] = df['Todesfaelle_gesamt'][0]
  30. # the recovery time is 14 days
  31. recovery_queue = np.zeros(14)
  32. for day in range(1, df.shape[0]):
  33. infections = df['Faelle_gesamt'][day] - df['Faelle_gesamt'][day-1]
  34. deaths = df['Todesfaelle_neu'][day]
  35. recoveries = recovery_queue[0]
  36. S[day] = S[day-1] - infections
  37. I[day] = I[day-1] + infections - deaths - recoveries
  38. R[day] = R[day-1] + deaths + recoveries
  39. # update recovery queue
  40. if I[day] < 0:
  41. recovery_queue[-1] -= I[day]
  42. I[day] = 0
  43. recovery_queue[:-1] = recovery_queue[1:]
  44. recovery_queue[-1] = infections
  45. t = np.arange(0, df.shape[0], 1)
  46. if plotter != None:
  47. # plot graphs
  48. plots = []
  49. labels = []
  50. if 'S' not in exclude:
  51. plots.append(S)
  52. labels.append('S')
  53. if 'I' not in exclude:
  54. plots.append(I)
  55. labels.append('I')
  56. if 'R' not in exclude:
  57. plots.append(R)
  58. labels.append('R')
  59. plotter.plot(t, plots, labels, plot_name, plot_title, plot_size, y_log_scale=yscale_log, plot_legend=plot_legend, xlabel='time / days', ylabel='amount of poeple')
  60. COVID_Data = np.asarray([t[0::sample_rate],
  61. S[0::sample_rate],
  62. I[0::sample_rate],
  63. R[0::sample_rate]])
  64. np.savetxt(f"datasets/SIR_RKI_{sample_rate}.csv", COVID_Data, delimiter=",")
  65. def get_state_cases(county_id, state_id):
  66. id = county_id // 1000
  67. return id == state_id
  68. def state_based_data(plotter:Plotter, state_name:str, time_range=1200, sample_rate=1, dataset_path='datasets/state_data/Aktuell_Deutschland_SarsCov2_Infektionen.csv'):
  69. """Transforms the RKI infection cases dataset to a SIR dataset.
  70. Args:
  71. plotter (Plotter): Plotter object to plot dataset curves.
  72. state_name (str): Name of the state that is to be singled out in the new dataset.
  73. time_range (int, optional): Number of days that will be looked at in the new dataset. Defaults to 1200.
  74. sample_rate (int, optional): Sample rate used to sample the timepoints. Defaults to 1.
  75. dataset_path (str, optional): Path to the CSV file, where the data is stored. Defaults to 'datasets/state_data/Aktuell_Deutschland_SarsCov2_Infektionen.csv'.
  76. """
  77. df = pd.read_csv(dataset_path)
  78. state_lookup = {'Schleswig Holstein' : (1, 2897000),
  79. 'Hamburg' : (2, 1841000),
  80. 'Niedersachsen' : (3, 7982000),
  81. 'Bremen' : (4, 569352),
  82. 'Nordrhein-Westfalen' : (5, 17930000),
  83. 'Hessen' : (6, 6266000),
  84. 'Rheinland-Pfalz' : (7, 4085000),
  85. 'Baden-Württemberg' : (8, 11070000),
  86. 'Bayern' : (9, 13080000),
  87. 'Saarland' : (10, 990509),
  88. 'Berlin' : (11, 3645000),
  89. 'Brandenburg' : (12, 2641000),
  90. 'Mecklenburg-Vorpommern' : (13, 1610000),
  91. 'Sachsen' : (14, 4078000),
  92. 'Sachsen-Anhalt' : (15, 2208000),
  93. 'Thüringen' : (16, 2143000)}
  94. state_ID, N = state_lookup[state_name]
  95. # single out a state
  96. state_IDs = df['IdLandkreis'] // 1000
  97. state_df = df.loc[state_IDs == state_ID]
  98. # sort entries by state
  99. state_df = state_df.sort_values('Refdatum')
  100. state_df = state_df.reset_index(drop=True)
  101. # collect cases
  102. infections = np.zeros(time_range)
  103. dead = np.zeros(time_range)
  104. recovered = np.zeros(time_range)
  105. entry_idx = 0
  106. day = 0
  107. date = state_df['Refdatum'][entry_idx]
  108. # check for each date all entries
  109. while day < time_range:
  110. # use the date sorted characteristic and take all entries with current date
  111. while state_df['Refdatum'][entry_idx] == date:
  112. # TODO use further parameters
  113. infections[day] += state_df['AnzahlFall'][entry_idx]
  114. dead[day] += state_df['AnzahlTodesfall'][entry_idx]
  115. recovered[day] += state_df['AnzahlGenesen'][entry_idx]
  116. entry_idx += 1
  117. # move day index by difference between the current and next date
  118. day += (pd.to_datetime(state_df['Refdatum'][entry_idx])-pd.to_datetime(date)).days
  119. date = state_df['Refdatum'][entry_idx]
  120. S = np.zeros(time_range)
  121. I = np.zeros(time_range)
  122. S[0] = N - infections[0]
  123. I[0] = infections[0]
  124. for day in range(1, time_range):
  125. S[day] = S[day-1] - infections[day]
  126. I[day] = I[day-1] + infections[day] - I[day-1]/14
  127. t = np.arange(0, time_range, 1)
  128. plotter.plot(t, [S, I], ['S', 'I'], state_name.replace(' ', '_').replace('-', '_'), state_name+' SI', (6,6), xlabel='time / days', ylabel='amount of people')
  129. COVID_Data = np.asarray([t[0::sample_rate],
  130. S[0::sample_rate],
  131. I[0::sample_rate]])
  132. np.savetxt(f"datasets/SIR_RKI_{state_name.replace(' ', '_').replace('-', '_')}_{sample_rate}.csv", COVID_Data, delimiter=",")