# dinn.py

import torch
import csv
import numpy as np
from enum import Enum
from .dataset import PandemicDataset
from .problem import PandemicProblem
from .plotter import Plotter


class Optimizer(Enum):
    ADAM = 0


class Scheduler(Enum):
    CYCLIC = 0
    CONSTANT = 1
    LINEAR = 2
    POLYNOMIAL = 3


class Activation(Enum):
    LINEAR = 0
    POWER = 1


def linear(x):
    return x


def power(x):
    return torch.float_power(x, 2)


class DINN:

    class NN(torch.nn.Module):
        def __init__(self,
                     output_size: int,
                     input_size: int,
                     hidden_size: int,
                     hidden_layers: int,
                     activation_layer,
                     t_init,
                     t_final,
                     output_activation_function=Activation.LINEAR,
                     use_glorot_initialization=False,
                     use_t_scaled=True) -> None:
            """Neural Network

            Args:
                output_size (int): number of outputs
                input_size (int): number of inputs
                hidden_size (int): number of hidden nodes per layer
                hidden_layers (int): number of hidden layers
                activation_layer (torch.nn.Module): activation layer
                t_init: first time point of the dataset (used for input scaling)
                t_final: last time point of the dataset (used for input scaling)
                output_activation_function (Activation, optional): activation applied to the output layer. Defaults to Activation.LINEAR.
                use_glorot_initialization (bool, optional): whether to initialize the weights with Glorot (Xavier) initialization. Defaults to False.
                use_t_scaled (bool, optional): whether to normalize the time input to [0, 1]. Defaults to True.
            """
            super(DINN.NN, self).__init__()
            if output_activation_function == Activation.LINEAR:
                self.out_activation = linear
            elif output_activation_function == Activation.POWER:
                self.out_activation = power
            else:
                print('Set output activation to default: linear')
                self.out_activation = linear
            self.input = torch.nn.Sequential(torch.nn.Linear(
                input_size, hidden_size), activation_layer)
            self.hidden = torch.nn.Sequential(*[torch.nn.Sequential(torch.nn.Linear(
                hidden_size, hidden_size), activation_layer) for _ in range(hidden_layers)])
            self.output = torch.nn.Linear(hidden_size, output_size)
            if use_glorot_initialization:
                torch.nn.init.xavier_uniform_(self.input[0].weight)
                for i in range(hidden_layers):
                    torch.nn.init.xavier_uniform_(self.hidden[i][0].weight)
                torch.nn.init.xavier_uniform_(self.output.weight)
            self.__t_init = t_init
            self.__t_final = t_final
            self.__use_t_scaled = use_t_scaled

        def forward(self, t):
            # normalize the time input to [0, 1] if scaling is enabled
            if self.__use_t_scaled:
                t_forward = (t - self.__t_init) / \
                    (self.__t_final - self.__t_init)
            else:
                t_forward = t
            x = self.input(t_forward)
            x = self.hidden(x)
            x = self.output(x)
            return self.out_activation(x)

    def __init__(self,
                 output_size: int,
                 data: PandemicDataset,
                 parameter_list: list,
                 problem: PandemicProblem,
                 plotter: Plotter,
                 state_variables=[],
                 parameter_regulator=torch.tanh,
                 input_size=1,
                 hidden_size=20,
                 hidden_layers=7,
                 activation_layer=torch.nn.ReLU(),
                 activation_output=Activation.LINEAR,
                 use_glorot_initialization=False) -> None:
        """Disease Informed Neural Network. Uses the PandemicProblem, DINN.NN and PandemicDataset to solve inverse problems and find the
        parameters of a specific mathematical model.

        Args:
            output_size (int): Number of output nodes of the NN.
            data (PandemicDataset): Data collected showing the course of the pandemic.
            parameter_list (list): List of the parameter names (strings) that are to be found.
            problem (PandemicProblem): Problem class implementing the calculation of the residuals.
            plotter (Plotter): Plotter object to plot dataset curves.
            state_variables (list, optional): List of the names of state variables. Defaults to [].
            parameter_regulator (optional): Function to force the parameters into a certain range. Defaults to torch.tanh.
            input_size (int, optional): Number of input nodes of the NN. Defaults to 1.
            hidden_size (int, optional): Number of hidden nodes per layer of the NN. Defaults to 20.
            hidden_layers (int, optional): Number of hidden layers of the NN. Defaults to 7.
            activation_layer (optional): Instance of the activation function. Defaults to torch.nn.ReLU().
            activation_output (Activation, optional): Activation applied to the output layer. Defaults to Activation.LINEAR.
            use_glorot_initialization (bool, optional): Whether to initialize the weights with Glorot (Xavier) initialization. Defaults to False.
        """
        assert len(state_variables) + \
            data.number_groups == output_size, f'The number of groups plus the number of state variables must equal the output size\nGroups:\t{data.number_groups}\nState variables:\t{len(state_variables)}\noutput_size: {output_size}\n'
        self.device = torch.device(data.device_name)
        self.device_name = data.device_name
        self.plotter = plotter
        self.model = DINN.NN(output_size,
                             input_size,
                             hidden_size,
                             hidden_layers,
                             activation_layer,
                             data.t_init,
                             data.t_final,
                             activation_output,
                             use_glorot_initialization=use_glorot_initialization,
                             use_t_scaled=data.use_scaled_time)
        self.model = self.model.to(self.device)
        self.data = data
        self.parameter_regulator = parameter_regulator
        self.problem = problem
        self.problem.def_grad_matrix(output_size)
        self.parameters_tilda = {}
        for parameter in parameter_list:
            self.parameters_tilda.update({parameter: torch.nn.Parameter(
                torch.rand(1, requires_grad=True, device=self.device_name))})
        # a new model has to be configured and then trained
        self.__is_configured = False
        self.__has_trained = False
        self.__state_variables = state_variables
        self.parameters = [np.zeros(1) for _ in range(len(parameter_list))]
        self.frames = []

    @property
    def number_state_variables(self):
        return len(self.__state_variables)

    def get_regulated_param(self, parameter_name: str):
        """Get one of the searched parameters, forced into a certain range by the parameter regulator.

        Args:
            parameter_name (str): Name of the parameter to be returned.

        Returns:
            torch.Tensor: Regulated value of the searched parameter.
        """
        return self.parameter_regulator(self.parameters_tilda[parameter_name])

    def get_parameters_tilda(self):
        """Get the original parameter values (not forced into any range).

        Returns:
            list: List of the raw torch.nn.Parameter objects of the searched parameters.
        """
        return list(self.parameters_tilda.values())

    def get_regulated_param_list(self):
        """Get the list of regulated parameters (forced into a specific range).

        Returns:
            list: List of regulated parameters.
        """
        return [self.parameter_regulator(parameter) for parameter in self.get_parameters_tilda()]

    def get_output(self, index):
        output = self.model(self.data.t_batch)
        return output[:, index]

    def configure_training(self,
                           lr: float,
                           epochs: int,
                           optimizer_class=Optimizer.ADAM,
                           scheduler_class=Scheduler.CYCLIC,
                           scheduler_factor=1,
                           lambda_obs=1,
                           lambda_physics=1,
                           verbose=False):
        """Set the optimizer, scheduler, learning rate and number of epochs for the following training process.

        Args:
            lr (float): Learning rate for the optimizer.
            epochs (int): Number of epochs the NN is supposed to be trained for.
            optimizer_class (Optimizer, optional): Optimizer that is supposed to be used. Defaults to Optimizer.ADAM.
            scheduler_class (Scheduler, optional): Learning rate scheduler that is supposed to be used. Defaults to Scheduler.CYCLIC.
            scheduler_factor (optional): Divisor applied to the number of epochs to obtain the scheduler's total_iters (used for LINEAR and POLYNOMIAL). Defaults to 1.
            lambda_obs (optional): Weight of the observation loss. Defaults to 1.
            lambda_physics (optional): Weight of the physics loss. Defaults to 1.
            verbose (bool, optional): Controls whether the configuration process is logged. Defaults to False.
        """
        parameter_list = list(self.model.parameters()) + \
            list(self.parameters_tilda.values())
        self.epochs = epochs
        self.lambda_obs = lambda_obs
        self.lambda_physics = lambda_physics
        match optimizer_class:
            case Optimizer.ADAM:
                self.optimizer = torch.optim.Adam(parameter_list, lr=lr)
            case _:
                self.optimizer = torch.optim.Adam(parameter_list, lr=lr)
                if verbose:
                    print('---------------------------------')
                    print(
                        f' Entered unknown optimizer name: {optimizer_class.name}\n Defaulted to ADAM.')
                    print('---------------------------------')
                optimizer_class = Optimizer.ADAM
        match scheduler_class:
            case Scheduler.CYCLIC:
                self.scheduler = torch.optim.lr_scheduler.CyclicLR(
                    self.optimizer, base_lr=lr * 10, max_lr=lr * 1e3, step_size_up=1000, mode="exp_range", gamma=0.85, cycle_momentum=False)
            case Scheduler.CONSTANT:
                self.scheduler = torch.optim.lr_scheduler.ConstantLR(
                    self.optimizer, factor=1, total_iters=4)
            case Scheduler.LINEAR:
                self.scheduler = torch.optim.lr_scheduler.LinearLR(
                    self.optimizer, start_factor=lr, total_iters=epochs / scheduler_factor)
            case Scheduler.POLYNOMIAL:
                self.scheduler = torch.optim.lr_scheduler.PolynomialLR(
                    self.optimizer, total_iters=epochs / scheduler_factor, power=1.0)
            case _:
                self.scheduler = torch.optim.lr_scheduler.CyclicLR(
                    self.optimizer, base_lr=lr * 10, max_lr=lr * 1e3, step_size_up=1000, mode="exp_range", gamma=0.85, cycle_momentum=False)
                if verbose:
                    print('---------------------------------')
                    print(
                        f' Entered unknown scheduler name: {scheduler_class.name}\n Defaulted to CYCLIC.')
                    print('---------------------------------')
                scheduler_class = Scheduler.CYCLIC
        if verbose:
            print(
                f'\nLearning Rate:\t{lr}\nOptimizer:\t{optimizer_class.name}\nScheduler:\t{scheduler_class.name}\n')
        self.__is_configured = True

    def train(self,
              create_animation=False,
              animation_sample_rate=500,
              verbose=False,
              do_split_training=False,
              start_split=10000):
        """Training routine for the DINN.

        Args:
            create_animation (bool, optional): Decides whether a prediction animation is created during training. Defaults to False.
            animation_sample_rate (int, optional): Sample rate of the prediction animation. Only used when create_animation=True. Defaults to 500.
            verbose (bool, optional): Controls whether the training process is logged. Defaults to False.
            do_split_training (bool, optional): If True, train on the observation loss alone first and add the physics loss only after start_split epochs. Defaults to False.
            start_split (int, optional): Epoch at which the physics loss is added when do_split_training=True. Defaults to 10000.
        """
        assert self.__is_configured, 'The model has to be configured before training through the use of self.configure_training.'
        if verbose:
            print(f'torch seed: {torch.seed()}')
        # arrays to hold values for plotting
        self.losses = np.zeros(self.epochs)
        self.obs_losses = np.zeros(self.epochs)
        self.physics_losses = np.zeros(self.epochs)
        self.parameters = [np.zeros(self.epochs) for _ in self.parameters]
        for epoch in range(self.epochs):
            # get the prediction and the fitting residuals
            prediction = self.model(self.data.t_batch)
            residuals = self.problem.residual(
                prediction, *self.get_regulated_param_list())
            self.optimizer.zero_grad()
            # calculate loss from the differential system
            loss_physics = 0
            for residual in residuals:
                loss_physics += torch.mean(torch.square(residual))
            loss_physics *= self.lambda_physics
            # calculate loss from the dataset
            loss_obs = 0
            for i, group in enumerate(self.data.group_names):
                loss_obs += torch.mean(torch.square(
                    self.data.get_norm(group) - prediction[:, i]))
            loss_obs *= self.lambda_obs
            if do_split_training:
                if epoch < start_split:
                    loss = loss_obs
                else:
                    loss = loss_obs + loss_physics
            else:
                loss = loss_obs + loss_physics
            loss.backward()
            self.optimizer.step()
            self.scheduler.step()
            # append values for plotting
            self.losses[epoch] = loss.item()
            self.obs_losses[epoch] = loss_obs.item()
            self.physics_losses[epoch] = loss_physics.item()
            for i, parameter in enumerate(self.parameters_tilda.items()):
                self.parameters[i][epoch] = self.get_regulated_param(
                    parameter[0]).item()
            # do snapshot for prediction animation
            if epoch % animation_sample_rate == 0 and create_animation:
                # snapshot of the I group (index 1) prediction against the data
                prediction = self.model(self.data.t_batch)
                t = torch.arange(
                    0, self.data.t_raw[-1].item(), (self.data.t_raw[-1] / self.data.t_raw.shape[0]).item())
                groups = self.data.get_denormalized_data(
                    [prediction[:, i] for i in range(self.data.number_groups)])
                plot_labels = ['I_pred', 'I_true']
                background_list = [0, 1]
                self.plotter.plot(t,
                                  [groups[1]] + [self.data.data[1]],
                                  plot_labels,
                                  'Frame',
                                  f'epoch {epoch}',
                                  figure_shape=(12, 6),
                                  is_frame=True,
                                  is_background=background_list,
                                  lw=3,
                                  legend_loc='upper right',
                                  xlabel='time / days',
                                  ylabel='amount of people')
            # print training advancements
            if epoch % 1000 == 0 and verbose:
                print(
                    f'\nEpoch {epoch} | LR {self.scheduler.get_last_lr()[0]}')
                print(f'physics loss:\t\t{loss_physics.item()}')
                print(f'observation loss:\t{loss_obs.item()}')
                print(f'loss:\t\t\t{loss.item()}')
                print('---------------------------------')
                if len(self.parameters_tilda.items()) != 0:
                    for parameter in self.parameters_tilda.items():
                        print(
                            f'{parameter[0]}:\t\t\t{self.parameter_regulator(parameter[1]).item()}')
                print('#################################')
        # create prediction animation
        if create_animation:
            self.plotter.animate(self.data.name + '_animation')
            self.plotter.reset_animation()
        self.__has_trained = True

    def plot_training_graphs(self, ground_truth=[]):
        """Plot the loss graph and the graphs of the advancement of the parameters.

        Args:
            ground_truth (list): List of the ground truth parameters.
        """
        assert self.__has_trained, 'Model has to be trained before plotting the training graphs'
        epochs = np.arange(0, self.epochs, 1)
        # plot loss
        self.plotter.plot(epochs,
                          [self.losses, self.obs_losses, self.physics_losses],
                          ['loss', 'observation loss', 'physics loss'],
                          self.data.name + '_loss',
                          'Loss',
                          (6, 6),
                          y_log_scale=True,
                          plot_legend=True,
                          xlabel='epochs')
        # plot parameters
        for i, parameter in enumerate(self.parameters):
            if len(ground_truth) > i:
                self.plotter.plot(epochs,
                                  [parameter,
                                   np.ones_like(epochs) * ground_truth[i]],
                                  ['prediction', 'ground truth'],
                                  self.data.name + '_' + list(self.parameters_tilda.items())[i][0],
                                  list(self.parameters_tilda.items())[i][0],
                                  (6, 6),
                                  is_background=[0, 1],
                                  xlabel='epochs')
            else:
                self.plotter.plot(epochs,
                                  [parameter],
                                  ['prediction'],
                                  self.data.name + '_' + list(self.parameters_tilda.items())[i][0],
                                  list(self.parameters_tilda.items())[i][0],
                                  (6, 6),
                                  xlabel='epochs',
                                  plot_legend=False)

    def save_training_process(self, title, save_predictions=True):
        """Save the loss curves and parameter trajectories of the training run to CSV files.

        Args:
            title (str): Prefix used for the output file names.
            save_predictions (bool, optional): Whether to also save the time grid, true data and prediction of each group. Defaults to True.
        """
        losses = {'loss': self.losses,
                  'obs_loss': self.obs_losses,
                  'physics_loss': self.physics_losses}
        for loss in losses.keys():
            with open(f'./results/training_metrics/{title}_{loss}.csv', 'w', newline='') as csvfile:
                writer = csv.writer(csvfile, delimiter=',')
                writer.writerow(losses[loss])
        for i, parameter in enumerate(self.parameters):
            with open(f'./results/training_metrics/{title}_{list(self.parameters_tilda.items())[i][0]}.csv', 'w', newline='') as csvfile:
                writer = csv.writer(csvfile, delimiter=',')
                writer.writerow(parameter)
        if save_predictions:
            prediction = self.model(self.data.t_batch)
            for i, group in enumerate(self.data.group_names):
                t = torch.linspace(
                    0, self.data.t_raw[-1].item(), self.data.t_raw.shape[0]).detach().cpu().numpy()
                true = self.data.get_group(group).detach().cpu().numpy()
                pred = self.data.get_denormalized_data([prediction[:, i]])[
                    0].detach().cpu().numpy()
                with open(f'./results/I_predictions/{title}_I_prediction.csv', 'w', newline='') as csvfile:
                    writer = csv.writer(csvfile, delimiter=',')
                    writer.writerow(t)
                    writer.writerow(true)
                    writer.writerow(pred)

    def plot_state_variables(self):
        """Plot the predicted curves of the state variables (the outputs that are not part of the dataset groups)."""
        prediction = self.model(self.data.t_batch)
        for i in range(self.data.number_groups, self.data.number_groups + self.number_state_variables):
            t = torch.linspace(
                0, self.data.t_raw[-1].item(), self.data.t_raw.shape[0])
            self.plotter.plot(t,
                              [prediction[:, i]],
                              [self.__state_variables[i - self.data.number_groups]],
                              f'{self.data.name}_{self.__state_variables[i - self.data.number_groups]}',
                              self.__state_variables[i - self.data.number_groups],
                              figure_shape=(12, 6),
                              plot_legend=True,
                              xlabel='time / days')
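

# Minimal usage sketch (illustrative only): how a DINN is typically configured and
# trained with the API defined above. The construction of the PandemicDataset,
# PandemicProblem and Plotter objects depends on their respective modules and is
# only assumed here; the parameter names ('alpha', 'beta') and the hyperparameter
# values are placeholders, not defaults prescribed by this file.
#
#   dinn = DINN(output_size=data.number_groups,
#               data=data,                      # PandemicDataset with the observed group curves
#               parameter_list=['alpha', 'beta'],
#               problem=problem,                # PandemicProblem subclass implementing residual()
#               plotter=plotter)                # Plotter used for all figures
#   dinn.configure_training(lr=1e-5,
#                           epochs=20000,
#                           scheduler_class=Scheduler.POLYNOMIAL,
#                           verbose=True)
#   dinn.train(verbose=True)
#   dinn.plot_training_graphs()
#   dinn.save_training_process('example_run')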