Bladeren bron

add nomalization

phillip.rothenbeck 1 jaar geleden
bovenliggende
commit
a44026842f
1 gewijzigde bestanden met toevoegingen van 11 en 3 verwijderingen
  1. 11 3
      src/dinn.py

+ 11 - 3
src/dinn.py

@@ -18,7 +18,9 @@ class DINN:
                      input_size: int,
                      hidden_size: int,
                      hidden_layers: int, 
-                     activation_layer) -> None:
+                     activation_layer, 
+                     t_init,
+                     t_final) -> None:
             """Neural Network
 
             Args:
@@ -33,9 +35,14 @@ class DINN:
             self.input = torch.nn.Sequential(torch.nn.Linear(input_size, hidden_size), activation_layer)
             self.hidden = torch.nn.Sequential(*[torch.nn.Sequential(torch.nn.Linear(hidden_size, hidden_size), activation_layer) for _ in range(hidden_layers)])
             self.output = torch.nn.Linear(hidden_size, output_size)
+            
+            self.__t_init = t_init
+            self.__t_final = t_final
 
         def forward(self, t):
-            x = self.input(t)
+            # normalize input
+            t_scaled = (t - self.__t_init) / (self.__t_final - self.__t_init)
+            x = self.input(t_scaled)
             x = self.hidden(x)
             x = self.output(x)
             return x
@@ -59,6 +66,7 @@ class DINN:
             data (PandemicDataset): Data collected showing the course of the pandemic
             parameter_list (list): List of the parameter names(strings), that are supposed to be found.
             problem (PandemicProblem): Problem class implementing the calculation of the residuals.
+            plotter (Plotter): Plotter object to plot dataset curves.
             parameter_regulator (optional): Function to force the parameters to be in a certain range. Defaults to torch.tanh.
             input_size (int, optional): Number of the input nodes of the NN. Defaults to 1.
             hidden_size (int, optional): Number of the hidden nodes of the NN. Defaults to 20.
@@ -70,7 +78,7 @@ class DINN:
         self.device_name = data.device_name
         self.plotter = plotter
 
-        self.model = DINN.NN(number_groups, input_size, hidden_size, hidden_layers, activation_layer)
+        self.model = DINN.NN(number_groups, input_size, hidden_size, hidden_layers, activation_layer, data.t_init, data.t_final)
         self.model = self.model.to(self.device)
         self.data = data
         self.parameter_regulator = parameter_regulator