|
|
@@ -20,16 +20,14 @@ class PandemicDataset:
|
|
|
self.device_name = 'cuda'
|
|
|
else:
|
|
|
self.device_name = 'cpu'
|
|
|
-
|
|
|
+
|
|
|
self.name = name
|
|
|
self.N = N
|
|
|
self.t_init = t.min()
|
|
|
self.t_final = t.max()
|
|
|
|
|
|
self.t_raw = torch.tensor(t, requires_grad=True, device=self.device_name)
|
|
|
- self.t_norm = torch.tensor((t - self.t_init) / (self.t_final - self.t_init), requires_grad=True, device=self.device_name)
|
|
|
-
|
|
|
- self.t_batch = self.t_norm.view(-1, 1).float()
|
|
|
+ self.t_batch = self.t_raw.view(-1, 1).float()
|
|
|
|
|
|
self.__group_dict = {}
|
|
|
for i, name in enumerate(group_names):
|
|
|
@@ -54,10 +52,6 @@ class PandemicDataset:
|
|
|
@property
|
|
|
def group_names(self):
|
|
|
return self.__group_names
|
|
|
-
|
|
|
- @property
|
|
|
- def normalization_differantial(self):
|
|
|
- return 1 / (self.t_final - self.t_init)
|
|
|
|
|
|
def get_normalized_data(self, data:list):
|
|
|
assert len(data) == self.number_groups, f'data parameter needs same length as there are groups in the dataset ({self.number_groups})'
|