|
|
@@ -3,8 +3,8 @@ import torch
|
|
|
class SIR_Dataset:
|
|
|
def __init__(self, N, t, S, I, R):
|
|
|
self.N = N
|
|
|
- self.t = torch.tensor(t, requires_grad=True).view(-1, 1).float()
|
|
|
- print(torch.min(self.t), torch.max(self.t))
|
|
|
+ self.t_raw = torch.tensor(t, requires_grad=True)
|
|
|
+ self.t_batch = self.t_raw.view(-1, 1).float()
|
|
|
|
|
|
self.S = torch.tensor(S)
|
|
|
self.I = torch.tensor(I)
|