@@ -3,8 +3,8 @@ import torch
class SIR_Dataset:
def __init__(self, N, t, S, I, R):
self.N = N
- self.t = torch.tensor(t, requires_grad=True).view(-1, 1).float()
- print(torch.min(self.t), torch.max(self.t))
+ self.t_raw = torch.tensor(t, requires_grad=True)
+ self.t_batch = self.t_raw.view(-1, 1).float()
self.S = torch.tensor(S)
self.I = torch.tensor(I)