浏览代码

adding generalized dataset class

phillip.rothenbeck 1 年之前
父节点
当前提交
f30c476b5e
共有 1 个文件被更改,包括 52 次插入0 次删除
  1. 52 0
      src/dataset.py

+ 52 - 0
src/dataset.py

@@ -0,0 +1,52 @@
+import torch
+
+class PandemicDataset:
+    def __init__(self, 
+                 group_names:list, 
+                 N: int, 
+                 t, 
+                 *groups):
+        """Class to hold all data for one training process.
+
+        Args:
+            group_names (list): Names of the groups, in which the population is split.
+            N (int): Size of the population.
+            t (np.array): Array of timesteps.
+            *groups (np.array): Arrays of size data for each group for each timestep..
+        """
+        self.N = N
+        self.t_raw = torch.tensor(t, requires_grad=True)
+        self.t_batch = self.t_raw.view(-1, 1).float()
+
+        self.__group_dict = {}
+        for i, name in enumerate(group_names):
+            self.__group_dict.update({name : i})
+
+        self.__group_names = group_names
+
+        self.__groups = [torch.tensor(group) for group in groups]
+        
+        self.__mins = [torch.min(group) for group in self.__groups]
+        self.__maxs = [torch.max(group) for group in self.__groups]
+        self.__norms = [(self.__groups[i] - self.__mins[i]) / (self.__maxs[i] - self.__mins[i]) for i in range(len(groups))]
+
+        self.number_groups = len(groups)
+
+    def get_group(self, name:str):
+        return self.__groups[self.__group_dict[name]]
+    
+    def get_min(self, name:str):
+        return self.__mins[self.__group_dict[name]]
+    
+    def get_max(self, name:str):
+        return self.__maxs[self.__group_dict[name]]
+    
+    def get_norm(self, name:str):
+        return self.__norms[self.__group_dict[name]]
+    
+    def get_group_names(self):
+        return self.__group_names
+    
+    def to_device(self, device):
+        self.t_raw = self.t_raw.to(device).detach().requires_grad_()
+        self.t_batch = self.t_batch.to(device).detach().requires_grad_()