@@ -18,7 +18,9 @@ class DINN:
                      input_size: int,
                      hidden_size: int,
                      hidden_layers: int,
-                     activation_layer) -> None:
+                     activation_layer,
+                     t_init,
+                     t_final) -> None:
             """Neural Network

             Args:
@@ -33,9 +35,14 @@ class DINN:
             self.input = torch.nn.Sequential(torch.nn.Linear(input_size, hidden_size), activation_layer)
             self.hidden = torch.nn.Sequential(*[torch.nn.Sequential(torch.nn.Linear(hidden_size, hidden_size), activation_layer) for _ in range(hidden_layers)])
             self.output = torch.nn.Linear(hidden_size, output_size)
+
+            self.__t_init = t_init
+            self.__t_final = t_final

         def forward(self, t):
-            x = self.input(t)
+            # normalize input
+            t_scaled = (t - self.__t_init) / (self.__t_final - self.__t_init)
+            x = self.input(t_scaled)
             x = self.hidden(x)
             x = self.output(x)
             return x
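
The new forward pass min-max scales the raw time input, so the network always sees values in [0, 1] regardless of how long the observation window is. A minimal sketch of what the scaling does, using hypothetical bounds of day 0 and day 150 (the values are illustrative, not taken from the dataset):

```python
import torch

# hypothetical observation window: day 0 through day 150
t_init, t_final = 0.0, 150.0
t = torch.tensor([[0.0], [75.0], [150.0]])

t_scaled = (t - t_init) / (t_final - t_init)
print(t_scaled)  # tensor([[0.0000], [0.5000], [1.0000]])
```

Mapping the time axis onto a fixed range keeps the first layer's pre-activations on a comparable scale no matter how many days the pandemic data spans.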
@@ -59,6 +66,7 @@ class DINN:
             data (PandemicDataset): Data collected showing the course of the pandemic
             parameter_list (list): List of the parameter names(strings), that are supposed to be found.
             problem (PandemicProblem): Problem class implementing the calculation of the residuals.
+            plotter (Plotter): Plotter object to plot dataset curves.
             parameter_regulator (optional): Function to force the parameters to be in a certain range. Defaults to torch.tanh.
             input_size (int, optional): Number of the input nodes of the NN. Defaults to 1.
             hidden_size (int, optional): Number of the hidden nodes of the NN. Defaults to 20.
@@ -70,7 +78,7 @@ class DINN:
         self.device_name = data.device_name
         self.plotter = plotter

-        self.model = DINN.NN(number_groups, input_size, hidden_size, hidden_layers, activation_layer)
+        self.model = DINN.NN(number_groups, input_size, hidden_size, hidden_layers, activation_layer, data.t_init, data.t_final)
         self.model = self.model.to(self.device)
         self.data = data
         self.parameter_regulator = parameter_regulator
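
Putting the hunks together, the inner network is an MLP whose only behavioural change is the time normalization before the first layer, with the bounds now supplied by the dataset (`data.t_init`, `data.t_final`). Below is a minimal, self-contained sketch of how that class reads after the patch; the standalone class name, `output_size` standing in for `number_groups`, and the example values are assumptions for illustration, not part of the repository:

```python
import torch


class NN(torch.nn.Module):
    """Sketch of the patched network: an MLP whose time input is
    min-max normalized to [0, 1] before the first linear layer."""

    def __init__(self, output_size, input_size, hidden_size, hidden_layers,
                 activation_layer, t_init, t_final):
        super().__init__()
        self.input = torch.nn.Sequential(
            torch.nn.Linear(input_size, hidden_size), activation_layer)
        self.hidden = torch.nn.Sequential(
            *[torch.nn.Sequential(torch.nn.Linear(hidden_size, hidden_size),
                                  activation_layer)
              for _ in range(hidden_layers)])
        self.output = torch.nn.Linear(hidden_size, output_size)
        self.__t_init = t_init
        self.__t_final = t_final

    def forward(self, t):
        # map the raw time axis [t_init, t_final] onto [0, 1]
        t_scaled = (t - self.__t_init) / (self.__t_final - self.__t_init)
        x = self.input(t_scaled)
        x = self.hidden(x)
        x = self.output(x)
        return x


# hypothetical usage: three compartments, time measured in days 0..150
model = NN(output_size=3, input_size=1, hidden_size=20, hidden_layers=4,
           activation_layer=torch.nn.ReLU(), t_init=0.0, t_final=150.0)
t = torch.linspace(0.0, 150.0, steps=151).reshape(-1, 1)
print(model(t).shape)  # torch.Size([151, 3])
```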