test_high_level_generators.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651
  1. import unittest
  2. import os
  3. from os.path import expanduser
  4. import torch
  5. from torchvision.datasets import CIFAR10, MNIST
  6. from torchvision.datasets.utils import download_url, extract_archive
  7. from torchvision.transforms import ToTensor
  8. from avalanche.benchmarks import dataset_benchmark, filelist_benchmark, \
  9. tensors_benchmark, paths_benchmark, data_incremental_benchmark, \
  10. benchmark_with_validation_stream
  11. from avalanche.benchmarks.scenarios.generic_benchmark_creation import \
  12. create_lazy_generic_benchmark, LazyStreamDefinition
  13. from avalanche.benchmarks.utils import AvalancheDataset, \
  14. AvalancheTensorDataset, AvalancheDatasetType
  15. from tests.unit_tests_utils import common_setups
  16. class HighLevelGeneratorTests(unittest.TestCase):
  def setUp(self):
      """Apply the shared unit-test configuration before every test.

      Delegates entirely to ``common_setups`` from ``tests.unit_tests_utils``.
      """
      common_setups()
  def test_dataset_benchmark(self):
      """Build a benchmark from plain torchvision datasets (MNIST, CIFAR-10).

      One experience is expected per input dataset in both the train and the
      test stream, and each experience is expected to keep the size of the
      dataset it was built from. The assertions make the test check more
      than "no exception was raised".
      """
      mnist_root = expanduser("~") + "/.avalanche/data/mnist/"
      cifar10_root = expanduser("~") + "/.avalanche/data/cifar10/"
      train_MNIST = MNIST(root=mnist_root, train=True, download=True)
      test_MNIST = MNIST(root=mnist_root, train=False, download=True)
      train_cifar10 = CIFAR10(root=cifar10_root, train=True, download=True)
      test_cifar10 = CIFAR10(root=cifar10_root, train=False, download=True)

      generic_benchmark = dataset_benchmark(
          [train_MNIST, train_cifar10],
          [test_MNIST, test_cifar10])

      self.assertEqual(2, len(generic_benchmark.train_stream))
      self.assertEqual(2, len(generic_benchmark.test_stream))
      for exp_idx, source in enumerate([train_MNIST, train_cifar10]):
          self.assertEqual(
              len(source),
              len(generic_benchmark.train_stream[exp_idx].dataset))
      for exp_idx, source in enumerate([test_MNIST, test_cifar10]):
          self.assertEqual(
              len(source),
              len(generic_benchmark.test_stream[exp_idx].dataset))
  def test_dataset_benchmark_avalanche_dataset(self):
      """Task labels attached to AvalancheDatasets must reach the experiences.

      MNIST is tagged with task 0 and CIFAR-10 with task 1. In both the train
      and the test stream, the experience at index ``i`` is therefore
      expected to report task label ``i``.
      """
      home = expanduser("~")
      mnist_root = home + "/.avalanche/data/mnist/"
      cifar10_root = home + "/.avalanche/data/cifar10/"

      def with_task(dataset_cls, root, is_train, task_label):
          # Download (if needed) a torchvision dataset and tag every pattern
          # with the same task label.
          return AvalancheDataset(
              dataset_cls(root=root, train=is_train, download=True),
              task_labels=task_label)

      train_mnist = with_task(MNIST, mnist_root, True, 0)
      test_mnist = with_task(MNIST, mnist_root, False, 0)
      train_cifar10 = with_task(CIFAR10, cifar10_root, True, 1)
      test_cifar10 = with_task(CIFAR10, cifar10_root, False, 1)

      benchmark = dataset_benchmark(
          [train_mnist, train_cifar10],
          [test_mnist, test_cifar10])

      for stream in (benchmark.train_stream, benchmark.test_stream):
          for exp_idx in (0, 1):
              # Experience exp_idx was built from the dataset tagged exp_idx.
              self.assertEqual(exp_idx, stream[exp_idx].task_label)
  def test_filelist_benchmark(self):
      """Build a benchmark from two Caffe-style filelists (cats, dogs).

      The cats-vs-dogs archive is downloaded and extracted under
      ``~/.avalanche/data``. One filelist per class is written, each line
      being ``<path relative to the train dir> <label>``. The filelists go to
      a temporary directory so the test leaves no stray files in the current
      working directory. The cats filelist doubles as the single test set.
      """
      import tempfile

      data_root = expanduser("~") + "/.avalanche/data"
      download_url(
          'https://storage.googleapis.com/mledu-datasets/'
          'cats_and_dogs_filtered.zip', data_root,
          'cats_and_dogs_filtered.zip')
      archive_name = os.path.join(data_root, 'cats_and_dogs_filtered.zip')
      extract_archive(archive_name, to_path=data_root + "/")
      dirpath = data_root + "/cats_and_dogs_filtered/train"

      with tempfile.TemporaryDirectory() as tmp_dir:
          filelists = [
              os.path.join(tmp_dir, "train_filelist_00.txt"),
              os.path.join(tmp_dir, "train_filelist_01.txt"),
          ]
          for filelist, class_dir, label in zip(
                  filelists, ["cats", "dogs"], [0, 1]):
              # Sorted so the filelist content is deterministic
              # (os.listdir order is arbitrary).
              filenames_list = sorted(
                  os.listdir(os.path.join(dirpath, class_dir)))
              with open(filelist, "w") as wf:
                  for name in filenames_list:
                      wf.write(
                          "{} {}\n".format(
                              os.path.join(class_dir, name), label))

          # Keep the temporary filelists alive while the benchmark is built
          # and inspected.
          generic_benchmark = filelist_benchmark(
              dirpath,
              filelists,
              [filelists[0]],
              task_labels=[0, 0],
              complete_test_set_only=True,
              train_transform=ToTensor(),
              eval_transform=ToTensor()
          )
          self.assertEqual(2, len(generic_benchmark.train_stream))
          self.assertEqual(1, len(generic_benchmark.test_stream))
  def test_paths_benchmark(self):
      """Build a benchmark from explicit ``(absolute_path, label)`` lists.

      The cats-vs-dogs training images become two experiences (cats → 0,
      dogs → 1). The cats experience is reused as the single test set.
      """
      data_dir = expanduser("~") + "/.avalanche/data"
      download_url(
          'https://storage.googleapis.com/mledu-datasets/'
          'cats_and_dogs_filtered.zip', data_dir,
          'cats_and_dogs_filtered.zip')
      extract_archive(
          os.path.join(data_dir, 'cats_and_dogs_filtered.zip'),
          to_path=data_dir + "/")
      train_dir = data_dir + "/cats_and_dogs_filtered/train"

      # One list of (path, label) tuples per class directory.
      train_experiences = [
          [(os.path.join(train_dir, class_dir, file_name), class_label)
           for file_name in os.listdir(os.path.join(train_dir, class_dir))]
          for class_dir, class_label in (("cats", 0), ("dogs", 1))
      ]

      benchmark = paths_benchmark(
          train_experiences,
          [train_experiences[0]],  # Single test set
          task_labels=[0, 0],
          complete_test_set_only=True,
          train_transform=ToTensor(),
          eval_transform=ToTensor())

      self.assertEqual(2, len(benchmark.train_stream))
      self.assertEqual(1, len(benchmark.test_stream))
  def test_tensors_benchmark(self):
      """Build a benchmark directly from ``(x, y)`` tensor pairs.

      Two training experiences (100 and 80 patterns) and one test experience
      (50 patterns). With ``complete_test_set_only=True`` the test stream is
      expected to hold a single experience.
      """
      shape = (3, 32, 32)
      train_tensors = [
          (torch.zeros(100, *shape), torch.zeros(100, dtype=torch.long)),
          (torch.zeros(80, *shape), torch.ones(80, dtype=torch.long)),
      ]
      test_tensors = [
          (torch.zeros(50, *shape), torch.zeros(50, dtype=torch.long)),
      ]

      benchmark = tensors_benchmark(
          train_tensors=train_tensors,
          test_tensors=test_tensors,
          task_labels=[0, 0],  # Task label of each train exp
          complete_test_set_only=True)

      self.assertEqual(2, len(benchmark.train_stream))
      self.assertEqual(1, len(benchmark.test_stream))
  def test_data_incremental_benchmark(self):
      """Split a tensor benchmark into fixed-size mini-experiences.

      Two training experiences (100 and 80 patterns) are split with size 12,
      ``shuffle=False`` and ``drop_last=False``. The expected result is 16
      mini-experiences that replay the original patterns in order, with the
      test stream left unchanged. Features are random rather than all-zero,
      so the per-pattern equality checks can actually detect misordered or
      misplaced patterns.
      """
      pattern_shape = (3, 32, 32)
      # Training experience 1
      experience_1_x = torch.rand(100, *pattern_shape)
      experience_1_y = torch.zeros(100, dtype=torch.long)
      # Training experience 2
      experience_2_x = torch.rand(80, *pattern_shape)
      experience_2_y = torch.ones(80, dtype=torch.long)
      # Test experience
      test_x = torch.rand(50, *pattern_shape)
      test_y = torch.zeros(50, dtype=torch.long)

      initial_benchmark_instance = tensors_benchmark(
          train_tensors=[(experience_1_x, experience_1_y),
                         (experience_2_x, experience_2_y)],
          test_tensors=[(test_x, test_y)],
          task_labels=[0, 0],  # Task label of each train exp
          complete_test_set_only=True)
      data_incremental_instance = data_incremental_benchmark(
          initial_benchmark_instance, 12, shuffle=False, drop_last=False)

      self.assertEqual(16, len(data_incremental_instance.train_stream))
      self.assertEqual(1, len(data_incremental_instance.test_stream))
      self.assertTrue(data_incremental_instance.complete_test_set_only)

      # 100 = 8 * 12 + 4 and 80 = 6 * 12 + 8: each original experience ends
      # with a short mini-experience and none spans two experiences, so the
      # mini-experiences, read in order, replay both experiences back to
      # back.
      expected_sizes = [12] * 8 + [4] + [12] * 6 + [8]
      all_train_x = torch.cat([experience_1_x, experience_2_x])
      all_train_y = torch.cat([experience_1_y, experience_2_y])

      tensor_idx = 0
      for exp_idx, exp in enumerate(data_incremental_instance.train_stream):
          self.assertEqual(exp_idx, exp.current_experience)
          self.assertEqual(expected_sizes[exp_idx], len(exp.dataset))
          for x, y, *_ in exp.dataset:
              self.assertTrue(torch.equal(all_train_x[tensor_idx], x))
              self.assertTrue(torch.equal(all_train_y[tensor_idx], y))
              tensor_idx += 1
      # Every training pattern was visited exactly once.
      self.assertEqual(len(all_train_x), tensor_idx)

      exp = data_incremental_instance.test_stream[0]
      self.assertEqual(50, len(exp.dataset))
      for tensor_idx, (x, y, *_) in enumerate(exp.dataset):
          self.assertTrue(torch.equal(test_x[tensor_idx], x))
          self.assertTrue(torch.equal(test_y[tensor_idx], y))
  def test_data_incremental_benchmark_from_lazy_benchmark(self):
      """Split a lazily generated benchmark into fixed-size mini-experiences.

      The input benchmark comes from ``create_lazy_generic_benchmark``, with
      two training datasets (100 and 80 patterns) and one test dataset
      (50 patterns). ``data_incremental_benchmark`` with size 12,
      ``shuffle=False`` and ``drop_last=False`` is expected to yield 16
      training mini-experiences that replay the original patterns in order,
      and to leave the test stream unchanged. Features are random so the
      content checks can tell patterns apart.
      """
      pattern_shape = (3, 32, 32)
      # Training experience 1
      experience_1_x = torch.rand(100, *pattern_shape)
      experience_1_y = torch.zeros(100, dtype=torch.long)
      experience_1_dataset = AvalancheTensorDataset(
          experience_1_x, experience_1_y)
      # Training experience 2
      experience_2_x = torch.rand(80, *pattern_shape)
      experience_2_y = torch.ones(80, dtype=torch.long)
      experience_2_dataset = AvalancheTensorDataset(
          experience_2_x, experience_2_y)
      # Test experience
      test_x = torch.rand(50, *pattern_shape)
      test_y = torch.zeros(50, dtype=torch.long)
      experience_test = AvalancheTensorDataset(test_x, test_y)

      def train_gen():
          # Lazy generator of the training stream
          for dataset in [experience_1_dataset, experience_2_dataset]:
              yield dataset

      def test_gen():
          # Lazy generator of the test stream
          for dataset in [experience_test]:
              yield dataset

      initial_benchmark_instance = create_lazy_generic_benchmark(
          train_generator=LazyStreamDefinition(train_gen(), 2, [0, 0]),
          test_generator=LazyStreamDefinition(test_gen(), 1, [0]),
          complete_test_set_only=True,
          dataset_type=AvalancheDatasetType.CLASSIFICATION)
      data_incremental_instance = data_incremental_benchmark(
          initial_benchmark_instance, 12, shuffle=False, drop_last=False)

      self.assertEqual(16, len(data_incremental_instance.train_stream))
      self.assertEqual(1, len(data_incremental_instance.test_stream))
      self.assertTrue(data_incremental_instance.complete_test_set_only)

      # 100 = 8 * 12 + 4 and 80 = 6 * 12 + 8: each original experience ends
      # with a short mini-experience and none spans two experiences, so the
      # mini-experiences, read in order, replay both experiences back to
      # back.
      expected_sizes = [12] * 8 + [4] + [12] * 6 + [8]
      all_train_x = torch.cat([experience_1_x, experience_2_x])
      all_train_y = torch.cat([experience_1_y, experience_2_y])

      tensor_idx = 0
      for exp_idx, exp in enumerate(data_incremental_instance.train_stream):
          self.assertEqual(exp_idx, exp.current_experience)
          self.assertEqual(expected_sizes[exp_idx], len(exp.dataset))
          for x, y, *_ in exp.dataset:
              self.assertTrue(torch.equal(all_train_x[tensor_idx], x))
              self.assertTrue(torch.equal(all_train_y[tensor_idx], y))
              tensor_idx += 1
      # Every training pattern was visited exactly once.
      self.assertEqual(len(all_train_x), tensor_idx)

      exp = data_incremental_instance.test_stream[0]
      self.assertEqual(50, len(exp.dataset))
      for tensor_idx, (x, y, *_) in enumerate(exp.dataset):
          self.assertTrue(torch.equal(test_x[tensor_idx], x))
          self.assertTrue(torch.equal(test_y[tensor_idx], y))
  def test_benchmark_with_validation_stream_fixed_size(self):
      """Move a fixed number of patterns (20) per training experience into a
      validation stream.

      With ``shuffle=False`` the test expects:
      - the first ``n - 20`` patterns of each experience to stay in the
        training stream;
      - the last 20 to form the matching validation experience;
      - the test stream to be unchanged.

      Features are random rather than all-zero, so the content checks detect
      patterns that end up in the wrong split or in the wrong order.
      """
      pattern_shape = (3, 32, 32)
      valid_size = 20
      # Training experience 1
      experience_1_x = torch.rand(100, *pattern_shape)
      experience_1_y = torch.zeros(100, dtype=torch.long)
      # Training experience 2
      experience_2_x = torch.rand(80, *pattern_shape)
      experience_2_y = torch.ones(80, dtype=torch.long)
      # Test experience
      test_x = torch.rand(50, *pattern_shape)
      test_y = torch.zeros(50, dtype=torch.long)

      initial_benchmark_instance = tensors_benchmark(
          train_tensors=[(experience_1_x, experience_1_y),
                         (experience_2_x, experience_2_y)],
          test_tensors=[(test_x, test_y)],
          task_labels=[0, 0],  # Task label of each train exp
          complete_test_set_only=True)
      valid_benchmark = benchmark_with_validation_stream(
          initial_benchmark_instance, valid_size, shuffle=False)

      self.assertEqual(2, len(valid_benchmark.train_stream))
      self.assertEqual(2, len(valid_benchmark.valid_stream))
      self.assertEqual(1, len(valid_benchmark.test_stream))
      self.assertTrue(valid_benchmark.complete_test_set_only)

      train_sources = [(experience_1_x, experience_1_y),
                       (experience_2_x, experience_2_y)]
      for exp_idx, (source_x, source_y) in enumerate(train_sources):
          n_train = len(source_x) - valid_size  # 80 and 60
          train_dataset = valid_benchmark.train_stream[exp_idx].dataset
          valid_dataset = valid_benchmark.valid_stream[exp_idx].dataset
          self.assertEqual(n_train, len(train_dataset))
          self.assertEqual(valid_size, len(valid_dataset))

          train_batch = train_dataset[:]
          valid_batch = valid_dataset[:]
          self.assertTrue(torch.equal(source_x[:n_train], train_batch[0]))
          self.assertTrue(torch.equal(source_y[:n_train], train_batch[1]))
          self.assertTrue(torch.equal(source_x[n_train:], valid_batch[0]))
          self.assertTrue(torch.equal(source_y[n_train:], valid_batch[1]))

      test_batch = valid_benchmark.test_stream[0].dataset[:]
      self.assertTrue(torch.equal(test_x, test_batch[0]))
      self.assertTrue(torch.equal(test_y, test_batch[1]))
  def test_benchmark_with_validation_stream_rel_size(self):
      """Move a relative share (20%) of every training experience into a
      validation stream.

      The test expects ``int(n * 0.2)`` validation patterns per experience.
      With ``shuffle=False`` the leading patterns stay in training and the
      trailing ones go to validation. Features are random so the content
      checks can detect misplaced or reordered patterns.
      """
      pattern_shape = (3, 32, 32)
      valid_fraction = 0.2
      # Training experience 1
      experience_1_x = torch.rand(100, *pattern_shape)
      experience_1_y = torch.zeros(100, dtype=torch.long)
      # Training experience 2
      experience_2_x = torch.rand(80, *pattern_shape)
      experience_2_y = torch.ones(80, dtype=torch.long)
      # Test experience
      test_x = torch.rand(50, *pattern_shape)
      test_y = torch.zeros(50, dtype=torch.long)

      initial_benchmark_instance = tensors_benchmark(
          train_tensors=[(experience_1_x, experience_1_y),
                         (experience_2_x, experience_2_y)],
          test_tensors=[(test_x, test_y)],
          task_labels=[0, 0],  # Task label of each train exp
          complete_test_set_only=True)
      valid_benchmark = benchmark_with_validation_stream(
          initial_benchmark_instance, valid_fraction, shuffle=False)

      self.assertEqual(2, len(valid_benchmark.train_stream))
      self.assertEqual(2, len(valid_benchmark.valid_stream))
      self.assertEqual(1, len(valid_benchmark.test_stream))
      self.assertTrue(valid_benchmark.complete_test_set_only)

      train_sources = [(experience_1_x, experience_1_y),
                       (experience_2_x, experience_2_y)]
      for exp_idx, (source_x, source_y) in enumerate(train_sources):
          expected_valid = int(len(source_x) * valid_fraction)
          expected_train = len(source_x) - expected_valid
          train_dataset = valid_benchmark.train_stream[exp_idx].dataset
          valid_dataset = valid_benchmark.valid_stream[exp_idx].dataset
          self.assertEqual(expected_train, len(train_dataset))
          self.assertEqual(expected_valid, len(valid_dataset))

          train_batch = train_dataset[:]
          valid_batch = valid_dataset[:]
          self.assertTrue(
              torch.equal(source_x[:expected_train], train_batch[0]))
          self.assertTrue(
              torch.equal(source_y[:expected_train], train_batch[1]))
          self.assertTrue(
              torch.equal(source_x[expected_train:], valid_batch[0]))
          self.assertTrue(
              torch.equal(source_y[expected_train:], valid_batch[1]))

      test_batch = valid_benchmark.test_stream[0].dataset[:]
      self.assertTrue(torch.equal(test_x, test_batch[0]))
      self.assertTrue(torch.equal(test_y, test_batch[1]))
  def test_lazy_benchmark_with_validation_stream_fixed_size(self):
      """Validation split (fixed size 20) of a lazily generated benchmark.

      For each ``lazy_splitting`` value in ``None``, ``True`` and ``False``
      this checks:
      - the laziness reported by the split ``train`` stream definition
        (the test expects ``None`` to behave like ``True`` here);
      - that a lazily split experience is not loaded before its first
        access, while an eager one already is;
      - sizes and contents of the train, valid and test streams.

      The order matters. Each "loaded yet?" check runs right before the
      first access to that experience, so the checks must keep that
      sequence. Features are random so the content checks can detect
      misplaced or reordered patterns.
      """
      valid_size = 20
      for lazy_option in [None, True, False]:
          with self.subTest(lazy_option=lazy_option):
              pattern_shape = (3, 32, 32)
              # Training experience 1
              experience_1_x = torch.rand(100, *pattern_shape)
              experience_1_y = torch.zeros(100, dtype=torch.long)
              experience_1_dataset = AvalancheTensorDataset(
                  experience_1_x, experience_1_y)
              # Training experience 2
              experience_2_x = torch.rand(80, *pattern_shape)
              experience_2_y = torch.ones(80, dtype=torch.long)
              experience_2_dataset = AvalancheTensorDataset(
                  experience_2_x, experience_2_y)
              # Test experience
              test_x = torch.rand(50, *pattern_shape)
              test_y = torch.zeros(50, dtype=torch.long)
              experience_test = AvalancheTensorDataset(test_x, test_y)

              def train_gen():
                  # Lazy generator of the training stream
                  for dataset in [experience_1_dataset, experience_2_dataset]:
                      yield dataset

              def test_gen():
                  # Lazy generator of the test stream
                  for dataset in [experience_test]:
                      yield dataset

              initial_benchmark_instance = create_lazy_generic_benchmark(
                  train_generator=LazyStreamDefinition(
                      train_gen(), 2, [0, 0]),
                  test_generator=LazyStreamDefinition(
                      test_gen(), 1, [0]),
                  complete_test_set_only=True,
                  dataset_type=AvalancheDatasetType.CLASSIFICATION)
              valid_benchmark = benchmark_with_validation_stream(
                  initial_benchmark_instance, valid_size, shuffle=False,
                  lazy_splitting=lazy_option)

              expect_laziness = lazy_option in (None, True)
              self.assertEqual(
                  expect_laziness,
                  valid_benchmark.stream_definitions['train'].is_lazy)
              self.assertEqual(2, len(valid_benchmark.train_stream))
              self.assertEqual(2, len(valid_benchmark.valid_stream))
              self.assertEqual(1, len(valid_benchmark.test_stream))
              self.assertTrue(valid_benchmark.complete_test_set_only)

              # (stream name, experience index, expected dataset size), in
              # the order each experience is first accessed.
              first_accesses = [
                  ('train', 0, 100 - valid_size),
                  ('train', 1, 80 - valid_size),
                  ('valid', 0, valid_size),
                  ('valid', 1, valid_size),
              ]
              for stream_name, exp_idx, expected_size in first_accesses:
                  exps_data = valid_benchmark.stream_definitions[
                      stream_name].exps_data
                  # Before its first access, a lazily split experience must
                  # not be loaded yet; an eagerly split one must be.
                  self.assertEqual(
                      expect_laziness,
                      exps_data.get_experience_if_loaded(exp_idx) is None)
                  stream = getattr(valid_benchmark, stream_name + '_stream')
                  self.assertEqual(
                      expected_size, len(stream[exp_idx].dataset))

              # After the accesses above, every split experience is loaded.
              for stream_name, exp_idx, _ in first_accesses:
                  self.assertIsNotNone(
                      valid_benchmark.stream_definitions[
                          stream_name].exps_data.get_experience_if_loaded(
                              exp_idx))

              train_sources = [(experience_1_x, experience_1_y),
                               (experience_2_x, experience_2_y)]
              for exp_idx, (source_x, source_y) in enumerate(train_sources):
                  n_train = len(source_x) - valid_size  # 80 and 60
                  train_batch = \
                      valid_benchmark.train_stream[exp_idx].dataset[:]
                  valid_batch = \
                      valid_benchmark.valid_stream[exp_idx].dataset[:]
                  self.assertTrue(
                      torch.equal(source_x[:n_train], train_batch[0]))
                  self.assertTrue(
                      torch.equal(source_y[:n_train], train_batch[1]))
                  self.assertTrue(
                      torch.equal(source_x[n_train:], valid_batch[0]))
                  self.assertTrue(
                      torch.equal(source_y[n_train:], valid_batch[1]))

              test_batch = valid_benchmark.test_stream[0].dataset[:]
              self.assertTrue(torch.equal(test_x, test_batch[0]))
              self.assertTrue(torch.equal(test_y, test_batch[1]))
  526. torch.equal(
  527. test_y,
  528. valid_benchmark.test_stream[0].dataset[:][1]))