| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081 |
- import torch
- from .dataset import PandemicDataset
class PandemicProblem:
    """Parent class for all pandemic problem classes.

    Subclasses override ``residual`` to compute the residuals of their
    differential system. ``def_grad_matrix`` must be called once before
    ``residual``. It builds the one-hot seed tensors used to differentiate a
    single output column of the prediction with respect to time.
    """

    def __init__(self, data: "PandemicDataset") -> None:
        """Store the dataset and the device name it reports.

        Args:
            data (PandemicDataset): Dataset holding the time values used.
        """
        self._data = data
        self._device_name = data.device_name
        # List of one-hot seed tensors (one per output column). It stays None
        # until def_grad_matrix is called, because the column count is chosen
        # by the caller.
        self._gradients = None

    def residual(self):
        """Verify that the gradient matrix has been defined.

        NEEDS TO BE IMPLEMENTED WHEN INHERITING FROM THIS CLASS. Overriding
        implementations call ``super().residual()`` first, so that a missing
        ``def_grad_matrix`` call is reported before any differentiation.

        Raises:
            AssertionError: If ``def_grad_matrix`` has not been called yet.
        """
        assert self._gradients is not None, 'Gradientmatrix need to be defined'

    def def_grad_matrix(self, number: int):
        """Build the one-hot seed tensors for ``number`` output columns.

        Creates ``number`` tensors, each of shape
        ``(len(data.t_raw), number)``. Tensor ``i`` is zero everywhere except
        column ``i``, which is one. The tensors are placed on the dataset's
        device.

        Args:
            number (int): Number of output columns of the prediction.

        Raises:
            AssertionError: If the gradient matrix was already defined.
        """
        assert self._gradients is None, 'Gradientmatrix is already defined'
        n_times = len(self._data.t_raw)
        gradients = []
        for column in range(number):
            selector = torch.zeros((n_times, number), device=self._device_name)
            selector[:, column] = 1
            gradients.append(selector)
        # Assign only once fully built, so a failure part-way never leaves a
        # partially filled list that would pass the `is not None` checks.
        self._gradients = gradients
class SIRProblem(PandemicProblem):
    """Residuals of the SIR system for a prediction with S, I and R columns."""

    def __init__(self, data: PandemicDataset):
        super().__init__(data)

    def _time_derivative(self, pred, column):
        """Return the derivative of ``pred[:, column]`` with respect to ``data.t_raw``."""
        # autograd.grad returns the derivative directly. backward() would also
        # accumulate it into t_raw.grad (on top of any stale value already
        # there) and into the .grad of every other leaf in the graph, e.g. the
        # network parameters. create_graph=True keeps the result differentiable,
        # so a loss on the residual also trains through the derivative term.
        (derivative,) = torch.autograd.grad(
            pred,
            self._data.t_raw,
            grad_outputs=self._gradients[column],
            create_graph=True,
        )
        return derivative

    def residual(self, SIR_pred, alpha, beta):
        """Compute the residuals of the S, I and R equations.

        Args:
            SIR_pred: Prediction whose columns 0, 1 and 2 are S, I and R. Its
                shape must match the tensors built by ``def_grad_matrix``
                (expected ``(len(data.t_raw), 3)``).
            alpha: Coefficient of the ``alpha * I`` transfer from I to R.
            beta: Coefficient of the ``beta * S * I / N`` transfer from S to I.

        Returns:
            Tuple ``(S_residual, I_residual, R_residual)``. Each is the time
            derivative of the predicted column minus the SIR right-hand side
            divided by that compartment's ``get_max - get_min`` range.

        Raises:
            AssertionError: If ``def_grad_matrix`` has not been called.
        """
        super().residual()
        dSdt = self._time_derivative(SIR_pred, 0)
        dIdt = self._time_derivative(SIR_pred, 1)
        dRdt = self._time_derivative(SIR_pred, 2)
        # NOTE(review): presumably the predictions are min-max normalised and
        # get_denormalized_data maps them back to absolute values; dividing the
        # right-hand sides by (max - min) then matches the normalised scale of
        # the derivatives. Confirm against PandemicDataset.
        S, I, _ = self._data.get_denormalized_data(
            [SIR_pred[:, 0], SIR_pred[:, 1], SIR_pred[:, 2]]
        )
        infections = beta * ((S * I) / self._data.N)
        recoveries = alpha * I
        S_residual = dSdt - (-infections) / (self._data.get_max('S') - self._data.get_min('S'))
        I_residual = dIdt - (infections - recoveries) / (self._data.get_max('I') - self._data.get_min('I'))
        R_residual = dRdt - recoveries / (self._data.get_max('R') - self._data.get_min('R'))
        return S_residual, I_residual, R_residual
class ReducedSIRProblem(PandemicProblem):
    """Residuals of a reduced SIR system driven by a predicted R_t column."""

    def __init__(self, data: PandemicDataset, alpha: float):
        """Store the dataset and the fixed ``alpha`` coefficient.

        Args:
            data (PandemicDataset): Dataset holding the time values used.
            alpha (float): Constant rate used in both residual equations
                (presumably the recovery rate; confirm).
        """
        super().__init__(data)
        self.alpha = alpha

    def _time_derivative(self, pred, column):
        """Return the derivative of ``pred[:, column]`` with respect to ``data.t_raw``."""
        # autograd.grad returns the derivative directly. backward() would also
        # accumulate it into t_raw.grad (on top of any stale value already
        # there) and into the .grad of every other leaf in the graph, e.g. the
        # network parameters. create_graph=True keeps the result differentiable,
        # so a loss on the residual also trains through the derivative term.
        (derivative,) = torch.autograd.grad(
            pred,
            self._data.t_raw,
            grad_outputs=self._gradients[column],
            create_graph=True,
        )
        return derivative

    def residual(self, SI_pred):
        """Compute the residuals of the reduced S and I equations.

        Args:
            SI_pred: Prediction whose columns 0 and 1 are S and I and whose
                column 2 is read as R_t. Its shape must match the tensors built
                by ``def_grad_matrix``.

        Returns:
            Tuple ``(S_residual, I_residual)``, where
            ``S_residual = dS/dt + alpha * R_t * I`` and
            ``I_residual = dI/dt - alpha * (R_t - 1) * I``, with I denormalised
            via ``get_denormalized_data``.

        Raises:
            AssertionError: If ``def_grad_matrix`` has not been called.
        """
        super().residual()
        dSdt = self._time_derivative(SI_pred, 0)
        dIdt = self._time_derivative(SI_pred, 1)
        _, I = self._data.get_denormalized_data([SI_pred[:, 0], SI_pred[:, 1]])
        # R_t is taken straight from the prediction, without denormalisation.
        R_t = SI_pred[:, 2]
        # NOTE(review): unlike SIRProblem, the right-hand sides here are not
        # divided by the (max - min) range of S / I, although dSdt and dIdt are
        # derivatives of the prediction columns. Confirm this scaling is intended.
        S_residual = dSdt - (-self.alpha * R_t * I)
        I_residual = dIdt - (self.alpha * (R_t - 1) * I)
        return S_residual, I_residual
|