phillip.rothenbeck 1 рік тому
батько
коміт
ff49c80d0c
1 змінених файлів з 12 додано та 5 видалено
  1. 12 5
      src/dataset.py

+ 12 - 5
src/dataset.py

@@ -2,6 +2,7 @@ import torch
 
 class PandemicDataset:
     def __init__(self, 
+                 name:str,
                  group_names:list, 
                  N: int, 
                  t, 
@@ -14,8 +15,14 @@ class PandemicDataset:
             t (np.array): Array of timesteps.
             *groups (np.array): Arrays of size data for each group for each timestep..
         """
+
+        if torch.cuda.is_available():
+            self.device_name = 'cuda'
+        else:
+            self.device_name = 'cpu'
+        self.name = name
         self.N = N
-        self.t_raw = torch.tensor(t, requires_grad=True)
+        self.t_raw = torch.tensor(t, requires_grad=True, device=self.device_name)
         self.t_batch = self.t_raw.view(-1, 1).float()
 
         self.__group_dict = {}
@@ -24,7 +31,7 @@ class PandemicDataset:
 
         self.__group_names = group_names
 
-        self.__groups = [torch.tensor(group) for group in groups]
+        self.__groups = [torch.tensor(group, device=self.device_name) for group in groups]
         
         self.__mins = [torch.min(group) for group in self.__groups]
         self.__maxs = [torch.max(group) for group in self.__groups]
@@ -32,6 +39,9 @@ class PandemicDataset:
 
         self.number_groups = len(groups)
 
+    def get_data(self):
+        return self.__groups
+
     def get_group(self, name:str):
         return self.__groups[self.__group_dict[name]]
     
@@ -47,6 +57,3 @@ class PandemicDataset:
     def get_group_names(self):
         return self.__group_names
     
-    def to_device(self, device):
-        self.t_raw = self.t_raw.to(device).detach().requires_grad_()
-        self.t_batch = self.t_batch.to(device).detach().requires_grad_()