|
|
@@ -0,0 +1,52 @@
|
|
|
+import torch
|
|
|
+
|
|
|
+class PandemicDataset:
|
|
|
+ def __init__(self,
|
|
|
+ group_names:list,
|
|
|
+ N: int,
|
|
|
+ t,
|
|
|
+ *groups):
|
|
|
+ """Class to hold all data for one training process.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ group_names (list): Names of the groups, in which the population is split.
|
|
|
+ N (int): Size of the population.
|
|
|
+ t (np.array): Array of timesteps.
|
|
|
+ *groups (np.array): Arrays of size data for each group for each timestep..
|
|
|
+ """
|
|
|
+ self.N = N
|
|
|
+ self.t_raw = torch.tensor(t, requires_grad=True)
|
|
|
+ self.t_batch = self.t_raw.view(-1, 1).float()
|
|
|
+
|
|
|
+ self.__group_dict = {}
|
|
|
+ for i, name in enumerate(group_names):
|
|
|
+ self.__group_dict.update({name : i})
|
|
|
+
|
|
|
+ self.__group_names = group_names
|
|
|
+
|
|
|
+ self.__groups = [torch.tensor(group) for group in groups]
|
|
|
+
|
|
|
+ self.__mins = [torch.min(group) for group in self.__groups]
|
|
|
+ self.__maxs = [torch.max(group) for group in self.__groups]
|
|
|
+ self.__norms = [(self.__groups[i] - self.__mins[i]) / (self.__maxs[i] - self.__mins[i]) for i in range(len(groups))]
|
|
|
+
|
|
|
+ self.number_groups = len(groups)
|
|
|
+
|
|
|
+ def get_group(self, name:str):
|
|
|
+ return self.__groups[self.__group_dict[name]]
|
|
|
+
|
|
|
+ def get_min(self, name:str):
|
|
|
+ return self.__mins[self.__group_dict[name]]
|
|
|
+
|
|
|
+ def get_max(self, name:str):
|
|
|
+ return self.__maxs[self.__group_dict[name]]
|
|
|
+
|
|
|
+ def get_norm(self, name:str):
|
|
|
+ return self.__norms[self.__group_dict[name]]
|
|
|
+
|
|
|
+ def get_group_names(self):
|
|
|
+ return self.__group_names
|
|
|
+
|
|
|
+ def to_device(self, device):
|
|
|
+ self.t_raw = self.t_raw.to(device).detach().requires_grad_()
|
|
|
+ self.t_batch = self.t_batch.to(device).detach().requires_grad_()
|