Pārlūkot izejas kodu

add norms and scaling

phillip.rothenbeck 4 mēneši atpakaļ
vecāks
revīzija
b724dac3ad
1 mainītis faili ar 55 papildinājumiem un 8 dzēšanām
  1. 55 8
      src/dataset.py

+ 55 - 8
src/dataset.py

@@ -1,4 +1,10 @@
 import torch
+from enum import Enum
+
class Norms(Enum):
    """Normalization strategies selectable for a ``PandemicDataset``.

    Members:
        POPULATION: divide each group by the total population size ``N``.
        MIN_MAX: scale each group to ``[0, 1]`` using its own min and max.
        CONSTANT: divide each group by a fixed constant ``C``.
    """

    # PEP 8: spaces around `=`; values kept identical for compatibility.
    POPULATION = 0
    MIN_MAX = 1
    CONSTANT = 2
 
 class PandemicDataset:
     def __init__(self, 
@@ -6,7 +12,10 @@ class PandemicDataset:
                  group_names:list, 
                  N: int, 
                  t, 
-                 *groups):
+                 *groups, 
+                 norm_name=Norms.MIN_MAX,
+                 C = 10**5,
+                 use_scaled_time=False):
         """Class to hold all data for one training process.
 
         Args:
@@ -15,19 +24,39 @@ class PandemicDataset:
             t (np.array): Array of timesteps.
             *groups (np.array): Arrays of size data for each group for each timestep.
         """
-
         if torch.cuda.is_available():
             self.device_name = 'cuda'
         else:
             self.device_name = 'cpu'
 
+        match norm_name:
+            case Norms.POPULATION:
+                self.__norm = self.__population_norm
+                self.__denorm = self.__population_denorm
+            case Norms.MIN_MAX:
+                self.__norm = self.__min_max_norm
+                self.__denorm = self.__min_max_denorm
+            case Norms.CONSTANT:
+                self.__norm = self.__constant_norm
+                self.__denorm = self.__constant_denorm
+            case _:
+                self.__norm = self.__min_max_norm
+                self.__denorm = self.__min_max_denorm
+
         self.name = name
         self.N = N
         self.t_init = t.min()
         self.t_final = t.max()
+        self.C = C
 
         self.t_raw = torch.tensor(t, requires_grad=True, device=self.device_name)
-        self.t_batch = self.t_raw.view(-1, 1).float()
+
+        self.t_scaled = ((self.t_raw - self.t_init) / (self.t_final - self.t_init)).detach().requires_grad_()
+        self.use_scaled_time = use_scaled_time
+        if use_scaled_time:
+            self.t_batch = self.t_scaled.view(-1, 1).float()
+        else:
+            self.t_batch = self.t_raw.view(-1, 1).float()
 
         self.__group_dict = {}
         for i, name in enumerate(group_names):
@@ -39,7 +68,7 @@ class PandemicDataset:
         
         self.__mins = [torch.min(group) for group in self.__groups]
         self.__maxs = [torch.max(group) for group in self.__groups]
-        self.__norms = [(self.__groups[i] - self.__mins[i]) / (self.__maxs[i] - self.__mins[i]) for i in range(len(groups))]
+        self.__norms = self.__norm(self.__groups)
 
     @property
     def number_groups(self):
@@ -52,14 +81,32 @@ class PandemicDataset:
    @property
    def group_names(self):
        # Read-only access to the dataset's group names.
        # NOTE(review): `__group_names` is assigned in `__init__` outside this
        # diff hunk — presumably the `group_names` list passed at construction;
        # verify against the full file.
        return self.__group_names
+    
+    def __population_norm(self, data):
+        return [(data[i] / self.N) for i in range(self.number_groups)]
+    
+    def __population_denorm(self, data):
+        return [(data[i] * self.N) for i in range(self.number_groups)]
+
+    def __min_max_norm(self, data):
+        return [(data[i] - self.__mins[i]) / (self.__maxs[i] - self.__mins[i]) for i in range(self.number_groups)]
+    
+    def __min_max_denorm(self, data):
+        return [(self.__mins[i] + (self.__maxs[i] - self.__mins[i]) * data[i]) for i in range(self.number_groups)]
+    
+    def __constant_norm(self, data):
+        return [(data[i] / self.C) for i in range(self.number_groups)]
+
+    def __constant_denorm(self, data):
+        return [(data[i] * self.C) for i in range(self.number_groups)]
 
     def get_normalized_data(self, data:list):
         assert len(data) == self.number_groups, f'data parameter needs same length as there are groups in the dataset ({self.number_groups})'
-        return [(data[i] - self.__mins[i]) / (self.__maxs[i] - self.__mins[i]) for i in range(self.number_groups)]
+        return self.__norm(data)
     
-    def get_denormalized_data(self, normalized_data:list):
-        assert len(normalized_data) == self.number_groups, f'normalized_data parameter needs same length as there are groups in the dataset ({self.number_groups})'
-        return [(self.__mins[i] + (self.__maxs[i] - self.__mins[i]) * normalized_data[i]) for i in range(self.number_groups)]
+    def get_denormalized_data(self, data:list):
+        assert len(data) == self.number_groups, f'data parameter needs same length as there are groups in the dataset ({self.number_groups})'
+        return self.__denorm(data)
 
     def get_group(self, name:str):
         return self.__groups[self.__group_dict[name]]