Ver código fonte

move normalization to dataset

phillip.rothenbeck 1 ano atrás
pai
commit
b26032e6c1
3 arquivos alterados com 58 adições e 33 exclusões
  1. 30 8
      src/dataset.py
  2. 11 6
      src/dinn.py
  3. 17 19
      src/problem.py

+ 30 - 8
src/dataset.py

@@ -13,17 +13,23 @@ class PandemicDataset:
             group_names (list): Names of the groups, in which the population is split.
             N (int): Size of the population.
             t (np.array): Array of timesteps.
-            *groups (np.array): Arrays of size data for each group for each timestep..
+            *groups (np.array): Arrays of size data for each group for each timestep.
         """
 
         if torch.cuda.is_available():
             self.device_name = 'cuda'
         else:
             self.device_name = 'cpu'
+            
         self.name = name
         self.N = N
+        self.t_init = t.min()
+        self.t_final = t.max()
+
         self.t_raw = torch.tensor(t, requires_grad=True, device=self.device_name)
-        self.t_batch = self.t_raw.view(-1, 1).float()
+        self.t_norm = torch.tensor((t - self.t_init) / (self.t_final - self.t_init), requires_grad=True, device=self.device_name)
+
+        self.t_batch = self.t_norm.view(-1, 1).float()
 
         self.__group_dict = {}
         for i, name in enumerate(group_names):
@@ -37,10 +43,29 @@ class PandemicDataset:
         self.__maxs = [torch.max(group) for group in self.__groups]
         self.__norms = [(self.__groups[i] - self.__mins[i]) / (self.__maxs[i] - self.__mins[i]) for i in range(len(groups))]
 
-        self.number_groups = len(groups)
-
-    def get_data(self):
+    @property
+    def number_groups(self):
+        return len(self.__group_names)
+    
+    @property
+    def data(self):
         return self.__groups
+    
+    @property
+    def group_names(self):
+        return self.__group_names
+    
+    @property
+    def normalization_differantial(self):  # d(t_norm)/dt for t_norm = (t - t_init) / (t_final - t_init); chain-rule factor back to raw time
+        return 1 / (self.t_final - self.t_init)
+
+    def get_normalized_data(self, data:list):  # min-max scale each entry of data with the stored per-group min/max
+        assert len(data) == self.number_groups, f'data parameter needs same length as there are groups in the dataset ({self.number_groups})'  # one entry per group, matched by index
+        return [(data[i] - self.__mins[i]) / (self.__maxs[i] - self.__mins[i]) for i in range(self.number_groups)]  # (x - min) / (max - min), group by group
+    
+    def get_denormalized_data(self, normalized_data:list):  # inverse of get_normalized_data: map scaled values back to the original range
+        assert len(normalized_data) == self.number_groups, f'normalized_data parameter needs same length as there are groups in the dataset ({self.number_groups})'  # one entry per group, matched by index
+        return [(self.__mins[i] + (self.__maxs[i] - self.__mins[i]) * normalized_data[i]) for i in range(self.number_groups)]  # min + (max - min) * x, group by group
 
     def get_group(self, name:str):
         return self.__groups[self.__group_dict[name]]
@@ -54,6 +79,3 @@ class PandemicDataset:
     def get_norm(self, name:str):
         return self.__norms[self.__group_dict[name]]
     
-    def get_group_names(self):
-        return self.__group_names
-    

+ 11 - 6
src/dinn.py

@@ -83,6 +83,8 @@ class DINN:
         self.epochs = None
 
         self.losses = np.zeros(1)
+        self.obs_losses = np.zeros(1)
+        self.physics_losses = np.zeros(1)
         self.parameters = [np.zeros(1) for _ in range(len(parameter_list))]
 
         self.frames = []
@@ -139,6 +141,8 @@ class DINN:
 
         # arrays to hold values for plotting
         self.losses = np.zeros(epochs)
+        self.obs_losses = np.zeros(epochs)
+        self.physics_losses = np.zeros(epochs)
         self.parameters = [np.zeros(epochs) for _ in self.parameters]
 
         for epoch in range(epochs):
@@ -155,7 +159,7 @@ class DINN:
 
             # calculate loss from the dataset
             loss_obs = 0
-            for i, group in enumerate(self.data.get_group_names()):
+            for i, group in enumerate(self.data.group_names):
                 loss_obs += torch.mean(torch.square(self.data.get_norm(group) - prediction[:, i]))
             
             loss = loss_physics + loss_obs
@@ -166,6 +170,8 @@ class DINN:
 
             # append values for plotting
             self.losses[epoch] = loss.item()
+            self.obs_losses[epoch] = loss_obs.item()
+            self.physics_losses[epoch] = loss_physics.item()
             for i, parameter in enumerate(self.parameters_tilda.items()):
                 self.parameters[i][epoch] = self.get_regulated_param(parameter[0]).item()
 
@@ -174,11 +180,10 @@ class DINN:
                 # prediction
                 prediction = self.model(self.data.t_batch)
                 t = torch.arange(0, self.data.t_raw[-1].item(), (self.data.t_raw[-1] / self.data.t_raw.shape[0]).item())
-                
-                groups = self.problem.denormalization(prediction)
+                groups = self.data.get_denormalized_data([prediction[:, 0], prediction[:, 1], prediction[:, 2]])
                 self.plotter.plot(t, 
-                                  groups + tuple(self.data.get_data()), 
-                                  [name + '_pred' for name in self.data.get_group_names()] + [name + '_true' for name in self.data.get_group_names()],
+                                  list(groups) + list(self.data.data), 
+                                  [name + '_pred' for name in self.data.group_names] + [name + '_true' for name in self.data.group_names],
                                   'frame',
                                   f'epoch {epoch}',
                                   figure_shape=(12, 6),
@@ -215,7 +220,7 @@ class DINN:
         epochs = np.arange(0, self.epochs, 1)
 
         # plot loss
-        self.plotter.plot(epochs, [self.losses], ['loss'], self.data.name + '_loss', 'Loss', (6, 6), y_log_scale=True, plot_legend=False, xlabel='epochs')
+        self.plotter.plot(epochs, [self.losses, self.obs_losses, self.physics_losses], ['loss', 'observation loss', 'physics loss'], self.data.name + '_loss', 'Loss', (6, 6), y_log_scale=True, plot_legend=True, xlabel='epochs')
         
         # plot parameters
         for i, parameter in enumerate(self.parameters):

+ 17 - 19
src/problem.py

@@ -23,39 +23,37 @@ class PandemicProblem:
         """
         pass
 
-    def denormalization(self):
-        """NEEDS TO BE IMPLEMENTED WHEN INHERITING FROM THIS CLASS
-        """
-        pass
-
 class SIRProblem(PandemicProblem):
     def __init__(self, data: PandemicDataset):
         super().__init__(data)
 
     def residual(self, SIR_pred, alpha, beta):
+        S_pred, I_pred, R_pred = SIR_pred[:, 0], SIR_pred[:, 1], SIR_pred[:, 2]  # columns of the prediction, used for denormalization below
+
+        # dSdt = torch.autograd.grad(S_pred, self.data.t_raw, torch.ones_like(S_pred), create_graph=True)[0]
         SIR_pred.backward(self.gradients[0], retain_graph=True)
-        dSdt = self.data.t_raw.grad.clone()
-        self.data.t_raw.grad.zero_()
+        dSdt_norm = self.data.t_norm.grad.clone()  # NOTE(review): .grad from backward() without create_graph is detached from the graph - confirm intended
+        self.data.t_norm.grad.zero_()  # reset so the next backward() does not accumulate onto this gradient
 
+        # dIdt = torch.autograd.grad(I_pred, self.data.t_raw, torch.ones_like(I_pred), create_graph=True)[0]
         SIR_pred.backward(self.gradients[1], retain_graph=True)
-        dIdt = self.data.t_raw.grad.clone()
-        self.data.t_raw.grad.zero_()
+        dIdt_norm = self.data.t_norm.grad.clone()
+        self.data.t_norm.grad.zero_()
 
+        # dRdt = torch.autograd.grad(R_pred, self.data.t_raw, torch.ones_like(R_pred), create_graph=True)[0]
         SIR_pred.backward(self.gradients[2], retain_graph=True)
-        dRdt = self.data.t_raw.grad.clone()
-        self.data.t_raw.grad.zero_()
+        dRdt_norm = self.data.t_norm.grad.clone()
+        self.data.t_norm.grad.zero_()
+        
+        S, I, R = self.data.get_denormalized_data([S_pred, I_pred, R_pred])  # back to population scale for the ODE right-hand sides
+        # print(f'dSdt: {dSdt}, dIdt: {dIdt}, dRdt: {dRdt}')
         
-        S, I, R = self.denormalization(SIR_pred)
+        dSdt = dSdt_norm * self.data.normalization_differantial  # chain rule: d/dt = d/dt_norm * 1 / (t_final - t_init)
+        dIdt = dIdt_norm * self.data.normalization_differantial  # fixed: was built from dRdt_norm (I and R derivatives were swapped)
+        dRdt = dRdt_norm * self.data.normalization_differantial  # fixed: was built from dIdt_norm
 
         S_residual = dSdt - (-beta * ((S * I) / self.data.N)) / (self.data.get_max('S') - self.data.get_min('S'))
         I_residual = dIdt - (beta * ((S * I) / self.data.N) - alpha * I) / (self.data.get_max('I') - self.data.get_min('I'))
         R_residual = dRdt - (alpha * I) / (self.data.get_max('R') - self.data.get_min('R'))
 
         return S_residual, I_residual, R_residual
-    
-    def denormalization(self, predictions):
-        S_pred, I_pred, R_pred = predictions[:, 0], predictions[:, 1], predictions[:, 2]
-        S = self.data.get_min('S') + (self.data.get_max('S') - self.data.get_min('S')) * S_pred
-        I = self.data.get_min('I') + (self.data.get_max('I') - self.data.get_min('I')) * I_pred
-        R = self.data.get_min('R') + (self.data.get_max('R') - self.data.get_min('R')) * R_pred
-        return S, I, R