|
|
@@ -8,17 +8,54 @@ class PandemicProblem:
|
|
|
Args:
|
|
|
data (PandemicDataset): Dataset holding the time values used.
|
|
|
"""
|
|
|
+
|
|
|
+ self.data = data
|
|
|
+ self.device_name = data.device_name
|
|
|
+
|
|
|
#store the gradients for each group
|
|
|
- self.gradients = [torch.zeros((len(data.t_raw), data.number_groups)) for _ in range(data.number_groups)]
|
|
|
+ self.gradients = [torch.zeros((len(data.t_raw), data.number_groups), device=self.device_name) for _ in range(data.number_groups)]
|
|
|
|
|
|
for i in range(data.number_groups):
|
|
|
self.gradients[i][:, i] = 1
|
|
|
|
|
|
- def to_device(self, device):
|
|
|
- for i in range(len(self.gradients)):
|
|
|
- self.gradients[i] = self.gradients[i].to(device)
|
|
|
+ def residual(self):
|
|
|
+ """NEEDS TO BE IMPLEMENTED WHEN INHERITING FROM THIS CLASS
|
|
|
+ """
|
|
|
+ pass
|
|
|
|
|
|
- def residual():
|
|
|
+ def denormalization(self):
|
|
|
"""NEEDS TO BE IMPLEMENTED WHEN INHERITING FROM THIS CLASS
|
|
|
"""
|
|
|
- pass
|
|
|
+ pass
|
|
|
+
|
|
|
+class SIRProblem(PandemicProblem):
|
|
|
+ def __init__(self, data: PandemicDataset):
|
|
|
+ super().__init__(data)
|
|
|
+
|
|
|
+ def residual(self, SIR_pred, alpha, beta):
|
|
|
+ SIR_pred.backward(self.gradients[0], retain_graph=True)
|
|
|
+ dSdt = self.data.t_raw.grad.clone()
|
|
|
+ self.data.t_raw.grad.zero_()
|
|
|
+
|
|
|
+ SIR_pred.backward(self.gradients[1], retain_graph=True)
|
|
|
+ dIdt = self.data.t_raw.grad.clone()
|
|
|
+ self.data.t_raw.grad.zero_()
|
|
|
+
|
|
|
+ SIR_pred.backward(self.gradients[2], retain_graph=True)
|
|
|
+ dRdt = self.data.t_raw.grad.clone()
|
|
|
+ self.data.t_raw.grad.zero_()
|
|
|
+
|
|
|
+ S, I, R = self.denormalization(SIR_pred)
|
|
|
+
|
|
|
+ S_residual = dSdt - (-beta * ((S * I) / self.data.N)) / (self.data.get_max('S') - self.data.get_min('S'))
|
|
|
+ I_residual = dIdt - (beta * ((S * I) / self.data.N) - alpha * I) / (self.data.get_max('I') - self.data.get_min('I'))
|
|
|
+ R_residual = dRdt - (alpha * I) / (self.data.get_max('R') - self.data.get_min('R'))
|
|
|
+
|
|
|
+ return S_residual, I_residual, R_residual
|
|
|
+
|
|
|
+ def denormalization(self, predictions):
|
|
|
+ S_pred, I_pred, R_pred = predictions[:, 0], predictions[:, 1], predictions[:, 2]
|
|
|
+ S = self.data.get_min('S') + (self.data.get_max('S') - self.data.get_min('S')) * S_pred
|
|
|
+ I = self.data.get_min('I') + (self.data.get_max('I') - self.data.get_min('I')) * I_pred
|
|
|
+ R = self.data.get_min('R') + (self.data.get_max('R') - self.data.get_min('R')) * R_pred
|
|
|
+ return S, I, R
|