import torch
from .dataset import PandemicDataset


class PandemicProblem:
    def __init__(self, data: PandemicDataset) -> None:
        """Parent class for all pandemic problem classes.

        Holds the function that calculates the residuals of the differential
        system.

        Args:
            data (PandemicDataset): Dataset holding the time values used.
        """
        self._data = data
        self._device_name = data.device_name
        # Per-output selector matrices built by ``def_grad_matrix``; stays
        # ``None`` until that method has been called.
        self._gradients = None

    def residual(self):
        """NEEDS TO BE IMPLEMENTED WHEN INHERITING FROM THIS CLASS.

        The base implementation only checks that ``def_grad_matrix`` has
        already been called. Subclasses call it via ``super().residual()``
        before computing any time derivatives.
        """
        assert self._gradients is not None, 'Gradientmatrix need to be defined'

    def def_grad_matrix(self, number: int):
        """Build one selector matrix per output column of the prediction.

        ``self._gradients[i]`` has shape ``(len(t_raw), number)``, with ones
        in column ``i`` and zeros elsewhere. It is passed as ``grad_outputs``
        to pick out the time derivative of the ``i``-th predicted quantity.
        Autograd requires it to have the same shape as the prediction tensor,
        so ``number`` must equal the prediction's column count.

        Args:
            number (int): Number of output columns of the prediction tensor.
        """
        assert self._gradients is None, 'Gradientmatrix is already defined'
        self._gradients = [
            torch.zeros((len(self._data.t_raw), number), device=self._device_name)
            for _ in range(number)
        ]
        for i in range(number):
            self._gradients[i][:, i] = 1

    def _time_derivative(self, pred: torch.Tensor, index: int) -> torch.Tensor:
        """Return the derivative of ``pred[:, index]`` w.r.t. ``self._data.t_raw``.

        ``torch.autograd.grad`` is used instead of ``pred.backward`` for two
        reasons:

        * The result stays attached to the graph (``create_graph=True``), so
          a loss built from the residual can back-propagate through the
          derivative term.
        * Nothing is accumulated into the ``.grad`` of the network parameters
          or of ``t_raw``.

        ``retain_graph=True`` keeps the graph alive for the next derivative
        and for the caller's own ``backward``.

        NOTE(review): reading this as a per-sample derivative assumes row
        ``k`` of ``pred`` depends only on entry ``k`` of ``t_raw``. That holds
        for a pointwise network over time inputs; confirm against the model.

        Args:
            pred (torch.Tensor): Prediction computed from ``self._data.t_raw``.
            index (int): Output column to differentiate.

        Returns:
            torch.Tensor: Derivative with the same shape as ``self._data.t_raw``.
        """
        (derivative,) = torch.autograd.grad(
            pred,
            self._data.t_raw,
            grad_outputs=self._gradients[index],
            retain_graph=True,
            create_graph=True,
        )
        return derivative


class SIRProblem(PandemicProblem):
    def __init__(self, data: PandemicDataset):
        """SIR system residuals. Requires ``def_grad_matrix(3)`` before use."""
        super().__init__(data)

    def residual(self, SIR_pred, alpha, beta):
        """Compute the residuals of the three SIR equations.

        The right-hand sides are evaluated on denormalized S and I. They are
        then divided by each compartment's (max - min) range, so they can be
        compared with derivatives of the network's normalized outputs.

        Args:
            SIR_pred (torch.Tensor): Network output with columns S, I, R
                (presumably normalized — see ``get_denormalized_data``).
            alpha: Coefficient of the I -> R transition term.
            beta: Coefficient of the S -> I transmission term.

        Returns:
            tuple: ``(S_residual, I_residual, R_residual)``.
        """
        super().residual()

        dSdt = self._time_derivative(SIR_pred, 0)
        dIdt = self._time_derivative(SIR_pred, 1)
        dRdt = self._time_derivative(SIR_pred, 2)

        S, I, _ = self._data.get_denormalized_data([SIR_pred[:, 0], SIR_pred[:, 1], SIR_pred[:, 2]])

        S_residual = dSdt - (-beta * ((S * I) / self._data.N)) / (self._data.get_max('S') - self._data.get_min('S'))
        I_residual = dIdt - (beta * ((S * I) / self._data.N) - alpha * I) / (self._data.get_max('I') - self._data.get_min('I'))
        R_residual = dRdt - (alpha * I) / (self._data.get_max('R') - self._data.get_min('R'))
        return S_residual, I_residual, R_residual


class ReducedSIRProblem(PandemicProblem):
    def __init__(self, data: PandemicDataset, alpha: float):
        """Reduced SIR residuals with a learned reproduction number R_t.

        Args:
            data (PandemicDataset): Dataset holding the time values used.
            alpha (float): Fixed coefficient of the I -> R transition term.
        """
        super().__init__(data)
        self.alpha = alpha

    def residual(self, SI_pred):
        """Compute the residuals of the reduced S/I equations.

        ``SI_pred`` is read at columns 0 (S), 1 (I) and 2 (R_t). The selector
        matrices must match its shape, so ``def_grad_matrix`` must have been
        called with ``SI_pred.shape[1]``.

        NOTE(review): unlike ``SIRProblem.residual``, the right-hand sides
        here are not divided by a (max - min) range, although I is
        denormalized while dS/dt and dI/dt come from the network output.
        Confirm this asymmetry is intended.

        Args:
            SI_pred (torch.Tensor): Network output with columns S, I, R_t.

        Returns:
            tuple: ``(S_residual, I_residual)``.
        """
        super().residual()

        dSdt = self._time_derivative(SI_pred, 0)
        dIdt = self._time_derivative(SI_pred, 1)

        _, I = self._data.get_denormalized_data([SI_pred[:, 0], SI_pred[:, 1]])
        R_t = SI_pred[:, 2]

        S_residual = dSdt - (-self.alpha * R_t * I)
        I_residual = dIdt - (self.alpha * (R_t - 1) * I)
        return S_residual, I_residual