import torch

from .dataset import PandemicDataset


class PandemicProblem:
    def __init__(self, data: PandemicDataset) -> None:
        """Parent class for all pandemic problem classes.

        Holds the function that calculates the residuals of the differential
        system.

        Args:
            data (PandemicDataset): Dataset holding the time values used.
        """
        self._data = data
        self._device_name = data.device_name
        # One-hot selector matrices created by def_grad_matrix(); None until then.
        self._gradients = None

    def residual(self):
        """NEEDS TO BE IMPLEMENTED WHEN INHERITING FROM THIS CLASS.

        The base implementation only checks the precondition shared by every
        subclass: def_grad_matrix() must have been called beforehand.
        Subclasses invoke it through ``super().residual()``.
        """
        assert self._gradients is not None, 'Gradientmatrix need to be defined'

    def def_grad_matrix(self, number: int):
        """Build the one-hot selector matrices used to differentiate predictions.

        Creates ``number`` tensors of shape ``(len(self._data.t_raw), number)``
        on the dataset's device. The i-th tensor is 1 in column i and 0
        elsewhere; used as ``grad_outputs`` it selects output column i of a
        prediction.

        Args:
            number (int): Number of output columns of the prediction.
        """
        assert self._gradients is None, 'Gradientmatrix is already defined'
        n_points = len(self._data.t_raw)
        self._gradients = []
        for i in range(number):
            selector = torch.zeros((n_points, number), device=self._device_name)
            selector[:, i] = 1
            self._gradients.append(selector)

    def _time_derivative(self, pred, t, index: int):
        """Return the derivative of output column ``index`` of ``pred`` w.r.t. ``t``.

        Uses ``torch.autograd.grad`` with ``self._gradients[index]`` as
        ``grad_outputs``. Unlike ``pred.backward(...)``, this does not read or
        accumulate into ``t.grad`` or the ``.grad`` of any other leaf (e.g.
        network parameters). ``create_graph=True`` keeps the derivative
        differentiable, so a loss built from the residual also back-propagates
        through the derivative term. ``retain_graph=True`` allows several
        columns to be differentiated from the same graph.

        NOTE(review): this yields per-point derivatives only if each row of
        ``pred`` depends solely on the matching entry of ``t`` (e.g. no
        batch-coupling layers) — confirm against the model.
        """
        (derivative,) = torch.autograd.grad(
            pred,
            t,
            grad_outputs=self._gradients[index],
            retain_graph=True,
            create_graph=True,
        )
        return derivative

    def _sir_residuals(self, SIR_pred, alpha, beta):
        """Shared SIR residual computation used by the SIR problem classes.

        Residuals of dS/dt = -beta*S*I/N, dI/dt = beta*S*I/N - alpha*I and
        dR/dt = alpha*I. The derivatives are taken of the network output
        w.r.t. ``self._data.t_raw``. S and I on the right-hand side are
        denormalized first. Each right-hand side is divided by the
        compartment's (max - min) range; presumably the dataset min-max
        normalizes each compartment, so this expresses the ODE in normalized
        units — confirm against PandemicDataset.

        Returns:
            tuple: (S_residual, I_residual, R_residual)
        """
        t = self._data.t_raw
        dSdt = self._time_derivative(SIR_pred, t, 0)
        dIdt = self._time_derivative(SIR_pred, t, 1)
        dRdt = self._time_derivative(SIR_pred, t, 2)

        S, I, _ = self._data.get_denormalized_data([SIR_pred[:, 0], SIR_pred[:, 1], SIR_pred[:, 2]])

        S_residual = dSdt - (-beta * ((S * I) / self._data.N)) / (self._data.get_max('S') - self._data.get_min('S'))
        I_residual = dIdt - (beta * ((S * I) / self._data.N) - alpha * I) / (self._data.get_max('I') - self._data.get_min('I'))
        R_residual = dRdt - (alpha * I) / (self._data.get_max('R') - self._data.get_min('R'))
        return S_residual, I_residual, R_residual


class SIRProblem(PandemicProblem):
    def __init__(self, data: PandemicDataset):
        """SIR problem where both alpha and beta are supplied per residual call.

        Args:
            data (PandemicDataset): Dataset holding the time values used.
        """
        super().__init__(data)

    def residual(self, SIR_pred, alpha, beta):
        """Compute the SIR residuals for a prediction.

        Requires def_grad_matrix() to have been called with at least 3 columns.

        Args:
            SIR_pred: Network output whose columns 0, 1 and 2 are S, I and R.
            alpha: Recovery-rate parameter.
            beta: Transmission-rate parameter.

        Returns:
            tuple: (S_residual, I_residual, R_residual)
        """
        super().residual()
        return self._sir_residuals(SIR_pred, alpha, beta)


class SIRAlphaProblem(PandemicProblem):
    def __init__(self, data: PandemicDataset, alpha):
        """SIR problem with a fixed alpha; only beta is supplied per call.

        Args:
            data (PandemicDataset): Dataset holding the time values used.
            alpha: Fixed recovery-rate parameter.
        """
        super().__init__(data)
        self.alpha = alpha

    def residual(self, SIR_pred, beta):
        """Compute the SIR residuals for a prediction using ``self.alpha``.

        Requires def_grad_matrix() to have been called with at least 3 columns.

        Args:
            SIR_pred: Network output whose columns 0, 1 and 2 are S, I and R.
            beta: Transmission-rate parameter.

        Returns:
            tuple: (S_residual, I_residual, R_residual)
        """
        super().residual()
        return self._sir_residuals(SIR_pred, self.alpha, beta)


class ReducedSIRProblem(PandemicProblem):
    def __init__(self, data: PandemicDataset, alpha: float):
        """Reduced SIR problem modelling only I together with a column R_t.

        Args:
            data (PandemicDataset): Dataset holding the time values used.
            alpha (float): Fixed recovery-rate parameter.
        """
        super().__init__(data)
        self.alpha = alpha

    def residual(self, I_pred):
        """Compute the residual of dI/dt = alpha * (t_final - t_init) * (R_t - 1) * I.

        The derivative is taken w.r.t. ``self._data.t_scaled``. The
        ``(t_final - t_init)`` factor presumably applies the chain rule for
        the rescaled time axis, and R_t is presumably the time-varying
        reproduction number — confirm against PandemicDataset and the model.
        Requires def_grad_matrix() to have been called.

        Args:
            I_pred: Network output whose column 0 is I and column 1 is R_t.

        Returns:
            The residual of the reduced I equation.
        """
        super().residual()
        dIdt = self._time_derivative(I_pred, self._data.t_scaled, 0)

        I = I_pred[:, 0]
        R_t = I_pred[:, 1]

        I_residual = dIdt - (self.alpha * (self._data.t_final - self._data.t_init) * (R_t - 1) * I)
        return I_residual