|
@@ -2,6 +2,7 @@ import torch
|
|
|
|
|
|
|
|
class PandemicDataset:
|
|
class PandemicDataset:
|
|
|
def __init__(self,
|
|
def __init__(self,
|
|
|
|
|
+ name:str,
|
|
|
group_names:list,
|
|
group_names:list,
|
|
|
N: int,
|
|
N: int,
|
|
|
t,
|
|
t,
|
|
@@ -14,8 +15,14 @@ class PandemicDataset:
|
|
|
t (np.array): Array of timesteps.
|
|
t (np.array): Array of timesteps.
|
|
|
*groups (np.array): Arrays of size data for each group for each timestep..
|
|
*groups (np.array): Arrays of size data for each group for each timestep..
|
|
|
"""
|
|
"""
|
|
|
|
|
+
|
|
|
|
|
+ if torch.cuda.is_available():
|
|
|
|
|
+ self.device_name = 'cuda'
|
|
|
|
|
+ else:
|
|
|
|
|
+ self.device_name = 'cpu'
|
|
|
|
|
+ self.name = name
|
|
|
self.N = N
|
|
self.N = N
|
|
|
- self.t_raw = torch.tensor(t, requires_grad=True)
|
|
|
|
|
|
|
+ self.t_raw = torch.tensor(t, requires_grad=True, device=self.device_name)
|
|
|
self.t_batch = self.t_raw.view(-1, 1).float()
|
|
self.t_batch = self.t_raw.view(-1, 1).float()
|
|
|
|
|
|
|
|
self.__group_dict = {}
|
|
self.__group_dict = {}
|
|
@@ -24,7 +31,7 @@ class PandemicDataset:
|
|
|
|
|
|
|
|
self.__group_names = group_names
|
|
self.__group_names = group_names
|
|
|
|
|
|
|
|
- self.__groups = [torch.tensor(group) for group in groups]
|
|
|
|
|
|
|
+ self.__groups = [torch.tensor(group, device=self.device_name) for group in groups]
|
|
|
|
|
|
|
|
self.__mins = [torch.min(group) for group in self.__groups]
|
|
self.__mins = [torch.min(group) for group in self.__groups]
|
|
|
self.__maxs = [torch.max(group) for group in self.__groups]
|
|
self.__maxs = [torch.max(group) for group in self.__groups]
|
|
@@ -32,6 +39,9 @@ class PandemicDataset:
|
|
|
|
|
|
|
|
self.number_groups = len(groups)
|
|
self.number_groups = len(groups)
|
|
|
|
|
|
|
|
|
|
+ def get_data(self):
|
|
|
|
|
+ return self.__groups
|
|
|
|
|
+
|
|
|
def get_group(self, name:str):
|
|
def get_group(self, name:str):
|
|
|
return self.__groups[self.__group_dict[name]]
|
|
return self.__groups[self.__group_dict[name]]
|
|
|
|
|
|
|
@@ -47,6 +57,3 @@ class PandemicDataset:
|
|
|
def get_group_names(self):
|
|
def get_group_names(self):
|
|
|
return self.__group_names
|
|
return self.__group_names
|
|
|
|
|
|
|
|
- def to_device(self, device):
|
|
|
|
|
- self.t_raw = self.t_raw.to(device).detach().requires_grad_()
|
|
|
|
|
- self.t_batch = self.t_batch.to(device).detach().requires_grad_()
|
|
|