|
@@ -1,6 +1,7 @@
|
|
|
import torch
|
|
|
from .dataset import PandemicDataset
|
|
|
|
|
|
+
|
|
|
class PandemicProblem:
|
|
|
def __init__(self, data: PandemicDataset) -> None:
|
|
|
"""Parent class for all pandemic problem classes. Holding the function, that calculates the residuals of the differential system.
|
|
@@ -18,14 +19,14 @@ class PandemicProblem:
|
|
|
"""NEEDS TO BE IMPLEMENTED WHEN INHERITING FROM THIS CLASS
|
|
|
"""
|
|
|
assert self._gradients != None, 'Gradientmatrix need to be defined'
|
|
|
-
|
|
|
|
|
|
- def def_grad_matrix(self, number:int):
|
|
|
+ def def_grad_matrix(self, number: int):
|
|
|
assert self._gradients == None, 'Gradientmatrix is already defined'
|
|
|
self._gradients = [torch.zeros((len(self._data.t_raw), number), device=self._device_name) for _ in range(number)]
|
|
|
for i in range(number):
|
|
|
self._gradients[i][:, i] = 1
|
|
|
|
|
|
+
|
|
|
class SIRProblem(PandemicProblem):
|
|
|
def __init__(self, data: PandemicDataset):
|
|
|
super().__init__(data)
|
|
@@ -53,8 +54,36 @@ class SIRProblem(PandemicProblem):
|
|
|
return S_residual, I_residual, R_residual
|
|
|
|
|
|
|
|
|
+class SIRAlphaProblem(PandemicProblem):
|
|
|
+ def __init__(self, data: PandemicDataset, alpha):
|
|
|
+ super().__init__(data)
|
|
|
+ self.alpha = alpha
|
|
|
+
|
|
|
+ def residual(self, SIR_pred, beta):
|
|
|
+ super().residual()
|
|
|
+ SIR_pred.backward(self._gradients[0], retain_graph=True)
|
|
|
+ dSdt = self._data.t_raw.grad.clone()
|
|
|
+ self._data.t_raw.grad.zero_()
|
|
|
+
|
|
|
+ SIR_pred.backward(self._gradients[1], retain_graph=True)
|
|
|
+ dIdt = self._data.t_raw.grad.clone()
|
|
|
+ self._data.t_raw.grad.zero_()
|
|
|
+
|
|
|
+ SIR_pred.backward(self._gradients[2], retain_graph=True)
|
|
|
+ dRdt = self._data.t_raw.grad.clone()
|
|
|
+ self._data.t_raw.grad.zero_()
|
|
|
+
|
|
|
+ S, I, _ = self._data.get_denormalized_data([SIR_pred[:, 0], SIR_pred[:, 1], SIR_pred[:, 2]])
|
|
|
+
|
|
|
+ S_residual = dSdt - (-beta * ((S * I) / self._data.N)) / (self._data.get_max('S') - self._data.get_min('S'))
|
|
|
+ I_residual = dIdt - (beta * ((S * I) / self._data.N) - self.alpha * I) / (self._data.get_max('I') - self._data.get_min('I'))
|
|
|
+ R_residual = dRdt - (self.alpha * I) / (self._data.get_max('R') - self._data.get_min('R'))
|
|
|
+
|
|
|
+ return S_residual, I_residual, R_residual
|
|
|
+
|
|
|
+
|
|
|
class ReducedSIRProblem(PandemicProblem):
|
|
|
- def __init__(self, data: PandemicDataset, alpha:float):
|
|
|
+ def __init__(self, data: PandemicDataset, alpha: float):
|
|
|
super().__init__(data)
|
|
|
self.alpha = alpha
|
|
|
|
|
@@ -72,4 +101,3 @@ class ReducedSIRProblem(PandemicProblem):
|
|
|
|
|
|
I_residual = dIdt - (self.alpha * (self._data.t_final - self._data.t_init) * (R_t - 1) * I)
|
|
|
return I_residual
|
|
|
-
|