problem.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. import torch
  2. from .dataset import PandemicDataset
  3. class PandemicProblem:
  4. def __init__(self, data: PandemicDataset) -> None:
  5. """Parent class for all pandemic problem classes. Holding the function, that calculates the residuals of the differential system.
  6. Args:
  7. data (PandemicDataset): Dataset holding the time values used.
  8. """
  9. self._data = data
  10. self._device_name = data.device_name
  11. self._gradients = None
  12. def residual(self):
  13. """NEEDS TO BE IMPLEMENTED WHEN INHERITING FROM THIS CLASS
  14. """
  15. assert self._gradients != None, 'Gradientmatrix need to be defined'
  16. def def_grad_matrix(self, number: int):
  17. assert self._gradients == None, 'Gradientmatrix is already defined'
  18. self._gradients = [torch.zeros((len(self._data.t_raw), number), device=self._device_name) for _ in range(number)]
  19. for i in range(number):
  20. self._gradients[i][:, i] = 1
  21. class SIRProblem(PandemicProblem):
  22. def __init__(self, data: PandemicDataset):
  23. super().__init__(data)
  24. def residual(self, SIR_pred, alpha, beta):
  25. super().residual()
  26. SIR_pred.backward(self._gradients[0], retain_graph=True)
  27. dSdt = self._data.t_raw.grad.clone()
  28. self._data.t_raw.grad.zero_()
  29. SIR_pred.backward(self._gradients[1], retain_graph=True)
  30. dIdt = self._data.t_raw.grad.clone()
  31. self._data.t_raw.grad.zero_()
  32. SIR_pred.backward(self._gradients[2], retain_graph=True)
  33. dRdt = self._data.t_raw.grad.clone()
  34. self._data.t_raw.grad.zero_()
  35. S, I, _ = self._data.get_denormalized_data([SIR_pred[:, 0], SIR_pred[:, 1], SIR_pred[:, 2]])
  36. S_residual = dSdt - (-beta * ((S * I) / self._data.N)) / (self._data.get_max('S') - self._data.get_min('S'))
  37. I_residual = dIdt - (beta * ((S * I) / self._data.N) - alpha * I) / (self._data.get_max('I') - self._data.get_min('I'))
  38. R_residual = dRdt - (alpha * I) / (self._data.get_max('R') - self._data.get_min('R'))
  39. return S_residual, I_residual, R_residual
  40. class SIRAlphaProblem(PandemicProblem):
  41. def __init__(self, data: PandemicDataset, alpha):
  42. super().__init__(data)
  43. self.alpha = alpha
  44. def residual(self, SIR_pred, beta):
  45. super().residual()
  46. SIR_pred.backward(self._gradients[0], retain_graph=True)
  47. dSdt = self._data.t_raw.grad.clone()
  48. self._data.t_raw.grad.zero_()
  49. SIR_pred.backward(self._gradients[1], retain_graph=True)
  50. dIdt = self._data.t_raw.grad.clone()
  51. self._data.t_raw.grad.zero_()
  52. SIR_pred.backward(self._gradients[2], retain_graph=True)
  53. dRdt = self._data.t_raw.grad.clone()
  54. self._data.t_raw.grad.zero_()
  55. S, I, _ = self._data.get_denormalized_data([SIR_pred[:, 0], SIR_pred[:, 1], SIR_pred[:, 2]])
  56. S_residual = dSdt - (-beta * ((S * I) / self._data.N)) / (self._data.get_max('S') - self._data.get_min('S'))
  57. I_residual = dIdt - (beta * ((S * I) / self._data.N) - self.alpha * I) / (self._data.get_max('I') - self._data.get_min('I'))
  58. R_residual = dRdt - (self.alpha * I) / (self._data.get_max('R') - self._data.get_min('R'))
  59. return S_residual, I_residual, R_residual
  60. class ReducedSIRProblem(PandemicProblem):
  61. def __init__(self, data: PandemicDataset, alpha: float):
  62. super().__init__(data)
  63. self.alpha = alpha
  64. def residual(self, I_pred):
  65. super().residual()
  66. I_pred.backward(self._gradients[0], retain_graph=True)
  67. dIdt = self._data.t_scaled.grad.clone()
  68. self._data.t_scaled.grad.zero_()
  69. I = I_pred[:, 0]
  70. R_t = I_pred[:, 1]
  71. # dIdt = torch.autograd.grad(I, self._data.t_scaled, torch.ones_like(I), create_graph=True)[0]
  72. I_residual = dIdt - (self.alpha * (self._data.t_final - self._data.t_init) * (R_t - 1) * I)
  73. return I_residual