Parcourir la source

implement model

phillip.rothenbeck il y a 10 mois
Parent
commit
e96794c310

+ 6 - 3
src/dinn.py

@@ -158,6 +158,8 @@ class DINN:
         match scheduler_name:
             case 'CyclicLR':
                 self.scheduler = torch.optim.lr_scheduler.CyclicLR(self.optimizer, base_lr=lr * 10, max_lr=lr * 1e3, step_size_up=1000, mode="exp_range", gamma=0.85, cycle_momentum=False)
+            case 'ConstantLR':
+                self.scheduler = torch.optim.lr_scheduler.ConstantLR(self.optimizer, factor=1, total_iters=4)
             case 'LinearLR':
                 self.scheduler = torch.optim.lr_scheduler.LinearLR(self.optimizer, start_factor=lr, total_iters=epochs/scheduler_factor)
             case 'PolynomialLR':
@@ -302,16 +304,17 @@ class DINN:
                                   plot_legend=False)
 
     def plot_state_variables(self):
+        prediction = self.model(self.data.t_batch)
+        groups = [prediction[:, i] for i in range(self.data.number_groups)]
+        fore_background = [0] + [1 for _ in groups]
         for i in range(self.data.number_groups, self.data.number_groups+self.number_state_variables):
-            prediction = self.model(self.data.t_batch)
-            groups = [prediction[:, i] for i in range(self.data.number_groups)]
             t = torch.arange(0, self.data.t_raw[-1].item(), (self.data.t_raw[-1] / self.data.t_raw.shape[0]).item())
             self.plotter.plot(t,
                               [prediction[:, i]] + groups,
                               [self.__state_variables[i-self.data.number_groups]] + self.data.group_names,
                               f'{self.data.name}_{self.__state_variables[i-self.data.number_groups]}',
                               self.__state_variables[i-self.data.number_groups],
-                              is_background=[0, 1, 1],
+                              is_background=fore_background,
                               figure_shape=(12, 6),
                               plot_legend=True,
                               xlabel='time / days')

+ 6 - 13
src/problem.py

@@ -60,22 +60,15 @@ class ReducedSIRProblem(PandemicProblem):
 
     def residual(self, SI_pred):
         super().residual()
-        SI_pred.backward(self._gradients[0], retain_graph=True)
-        dSdt = self._data.t_raw.grad.clone()
-        self._data.t_raw.grad.zero_()
 
-        SI_pred.backward(self._gradients[1], retain_graph=True)
+        SI_pred.backward(self._gradients[0], retain_graph=True)
         dIdt = self._data.t_raw.grad.clone()
         self._data.t_raw.grad.zero_()
+        
+        I = SI_pred[:, 0]
+        R_t = SI_pred[:, 1]
 
-        _, I = self._data.get_denormalized_data([SI_pred[:, 0], SI_pred[:, 1]])
-        R_t = SI_pred[:, 2]
-        # I = SI_pred[:, 1]
-
-        S_residual = dSdt - (-self.alpha * R_t * I)
-        I_residual = dIdt - (self.alpha * (R_t - 1) * I)
-
-        # print(f'\nTrue:\tI_min: {I.min()}, I_max: {I.max()}\nNorm:\tI_min: {SI_pred[:, 1].min()}, I_max: {SI_pred[:, 1].max()}\nResidual:\t{torch.mean(torch.square(I_residual))}')
+        I_residual = dIdt - (self.alpha * (self._data.t_final - self._data.t_init) * (R_t - 1) * I)
 
-        return S_residual, I_residual
+        return I_residual
 

Fichier diff supprimé car celui-ci est trop grand
+ 40 - 39
synth_dinn_reduced_sir.ipynb


BIN
visualizations/synth_sir_R_t.png


BIN
visualizations/synth_sir_animation.gif


BIN
visualizations/synth_sir_loss.png


Certains fichiers n'ont pas été affichés car il y a eu trop de fichiers modifiés dans ce diff