@@ -1,4 +1,10 @@
 import torch
+from enum import Enum
+
+class Norms(Enum):
+    POPULATION=0
+    MIN_MAX=1
+    CONSTANT=2
 
 class PandemicDataset:
     def __init__(self,
@@ -6,7 +12,10 @@ class PandemicDataset:
                  group_names:list,
                  N: int,
                  t,
-                 *groups):
+                 *groups,
+                 norm_name=Norms.MIN_MAX,
+                 C = 10**5,
+                 use_scaled_time=False):
         """Class to hold all data for one training process.
 
         Args:
@@ -15,19 +24,39 @@ class PandemicDataset:
             t (np.array): Array of timesteps.
            *groups (np.array): Arrays of size data for each group for each timestep.
         """
-
         if torch.cuda.is_available():
            self.device_name = 'cuda'
         else:
            self.device_name = 'cpu'
 
+        match norm_name:
+            case Norms.POPULATION:
+                self.__norm = self.__population_norm
+                self.__denorm = self.__population_denorm
+            case Norms.MIN_MAX:
+                self.__norm = self.__min_max_norm
+                self.__denorm = self.__min_max_denorm
+            case Norms.CONSTANT:
+                self.__norm = self.__constant_norm
+                self.__denorm = self.__constant_denorm
+            case _:
+                self.__norm = self.__min_max_norm
+                self.__denorm = self.__min_max_denorm
+
         self.name = name
         self.N = N
         self.t_init = t.min()
         self.t_final = t.max()
+        self.C = C
 
         self.t_raw = torch.tensor(t, requires_grad=True, device=self.device_name)
-        self.t_batch = self.t_raw.view(-1, 1).float()
+
+        self.t_scaled = ((self.t_raw - self.t_init) / (self.t_final - self.t_init)).detach().requires_grad_()
+        self.use_scaled_time = use_scaled_time
+        if use_scaled_time:
+            self.t_batch = self.t_scaled.view(-1, 1).float()
+        else:
+            self.t_batch = self.t_raw.view(-1, 1).float()
 
         self.__group_dict = {}
         for i, name in enumerate(group_names):
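
Illustration (a standalone sketch, not code from this patch): the new t_scaled tensor rescales the raw time axis from [t_init, t_final] onto [0, 1] and then re-enables gradients on the detached result. With plain tensors the transform behaves as follows:

    import torch

    t = torch.tensor([10., 20., 30., 40.], requires_grad=True)   # raw timesteps
    t_scaled = ((t - t.min()) / (t.max() - t.min())).detach().requires_grad_()
    print(t_scaled)                      # tensor([0.0000, 0.3333, 0.6667, 1.0000], requires_grad=True)
    print(t_scaled.view(-1, 1).shape)    # torch.Size([4, 1]), the column shape used for t_batch
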
@@ -39,7 +68,7 @@ class PandemicDataset:
 
         self.__mins = [torch.min(group) for group in self.__groups]
         self.__maxs = [torch.max(group) for group in self.__groups]
-        self.__norms = [(self.__groups[i] - self.__mins[i]) / (self.__maxs[i] - self.__mins[i]) for i in range(len(groups))]
+        self.__norms = self.__norm(self.__groups)
 
     @property
     def number_groups(self):
@@ -52,14 +81,32 @@ class PandemicDataset:
     @property
     def group_names(self):
         return self.__group_names
+
+    def __population_norm(self, data):
+        return [(data[i] / self.N) for i in range(self.number_groups)]
+
+    def __population_denorm(self, data):
+        return [(data[i] * self.N) for i in range(self.number_groups)]
+
+    def __min_max_norm(self, data):
+        return [(data[i] - self.__mins[i]) / (self.__maxs[i] - self.__mins[i]) for i in range(self.number_groups)]
+
+    def __min_max_denorm(self, data):
+        return [(self.__mins[i] + (self.__maxs[i] - self.__mins[i]) * data[i]) for i in range(self.number_groups)]
+
+    def __constant_norm(self, data):
+        return [(data[i] / self.C) for i in range(self.number_groups)]
+
+    def __constant_denorm(self, data):
+        return [(data[i] * self.C) for i in range(self.number_groups)]
 
     def get_normalized_data(self, data:list):
         assert len(data) == self.number_groups, f'data parameter needs same length as there are groups in the dataset ({self.number_groups})'
-        return [(data[i] - self.__mins[i]) / (self.__maxs[i] - self.__mins[i]) for i in range(self.number_groups)]
+        return self.__norm(data)
 
-    def get_denormalized_data(self, normalized_data:list):
-        assert len(normalized_data) == self.number_groups, f'normalized_data parameter needs same length as there are groups in the dataset ({self.number_groups})'
-        return [(self.__mins[i] + (self.__maxs[i] - self.__mins[i]) * normalized_data[i]) for i in range(self.number_groups)]
+    def get_denormalized_data(self, data:list):
+        assert len(data) == self.number_groups, f'data parameter needs same length as there are groups in the dataset ({self.number_groups})'
+        return self.__denorm(data)
 
     def get_group(self, name:str):
         return self.__groups[self.__group_dict[name]]
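
Usage sketch (illustrative only, not part of the patch): constructing a dataset with the new keyword-only options and round-tripping data through the selected normalization. It assumes the constructor's leading positional parameter, which falls outside the hunks above, is the dataset name (suggested by self.name = name) and that t and the group arrays are NumPy arrays as the docstring states; the names toy_outbreak, infected and recovered are made up for the example.

    import numpy as np

    t = np.linspace(0.0, 99.0, 100)      # 100 daily timesteps
    infected = 50.0 * t                  # toy case counts per timestep
    recovered = 5.0 * t

    dataset = PandemicDataset('toy_outbreak',              # assumed name argument
                              ['I', 'R'],                  # group names
                              1_000_000,                   # population size N
                              t,
                              infected, recovered,
                              norm_name=Norms.POPULATION,  # divide by N instead of min-max scaling
                              use_scaled_time=True)        # feed t mapped to [0, 1] as t_batch

    # Population norm divides by N on the way in and multiplies by N on the way out.
    normed = dataset.get_normalized_data([dataset.get_group('I'), dataset.get_group('R')])
    restored = dataset.get_denormalized_data(normed)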