Prechádzať zdrojové kódy

define split t into raw and batch

phillip.rothenbeck 1 rok pred
rodič
commit
1dfdc3934a
1 zmenil súbory, kde vykonal 2 pridanie a 2 odobranie
  1. 2 2
      src/sir_dinn/dataset/dataset.py

+ 2 - 2
src/sir_dinn/dataset/dataset.py

@@ -3,8 +3,8 @@ import torch
 class SIR_Dataset:
     def __init__(self, N, t, S, I, R):
         self.N = N
-        self.t = torch.tensor(t, requires_grad=True).view(-1, 1).float()
-        print(torch.min(self.t), torch.max(self.t))
+        self.t_raw = torch.tensor(t, requires_grad=True)
+        self.t_batch = self.t_raw.view(-1, 1).float()
 
         self.S = torch.tensor(S)
         self.I = torch.tensor(I)