test_avalanche_dataset.py 77 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761277127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616161716181619162016211622162316241625162616271628162916301631163216331634163516361637163816391640164116421643164416451646164716481649165016511652165316541655165616571658165916601661166216631664166516661667166816691670167116721673167416751676167716781679168016811682168316841685168616871688168916901691169216931694169516961697169816991700170117021703170417051706170717081709171017111712171317141715171617171718171917201721172217231724172517261727172817291730173117321733173417351736173717381739174017411742174317441745174617471748174917501751175217531754175517561757175817591760176117621763176417651766176717681769177017711772177317741775177617771778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039
  1. import unittest
  2. from os.path import expanduser
  3. from avalanche.models import SimpleMLP
  4. from torch.optim import SGD
  5. from torch.nn import CrossEntropyLoss
  6. from avalanche.training.strategies import Naive
  7. from avalanche.benchmarks.generators import dataset_benchmark
  8. import PIL
  9. import torch
  10. from PIL import ImageChops
  11. from PIL.Image import Image
  12. from torch import Tensor
  13. from torch.utils.data import TensorDataset, Subset, ConcatDataset
  14. from torchvision.datasets import MNIST
  15. from torchvision.transforms import ToTensor, RandomCrop, ToPILImage, Compose, \
  16. Lambda, CenterCrop
  17. from typing import List
  18. from avalanche.benchmarks.scenarios.generic_benchmark_creation import \
  19. create_generic_benchmark_from_tensor_lists
  20. from avalanche.benchmarks.utils import AvalancheDataset, \
  21. AvalancheSubset, AvalancheConcatDataset, AvalancheDatasetType, \
  22. AvalancheTensorDataset
  23. from avalanche.benchmarks.utils.dataset_utils import ConstantSequence
  24. from avalanche.training.utils import load_all_dataset
  25. import random
  26. import numpy as np
  27. def pil_images_equal(img_a, img_b):
  28. diff = ImageChops.difference(img_a, img_b)
  29. return not diff.getbbox()
  30. class AvalancheDatasetTests(unittest.TestCase):
  31. def test_mnist_no_transforms(self):
  32. dataset = MNIST(root=expanduser("~") + "/.avalanche/data/mnist/",
  33. download=True)
  34. x, y = dataset[0]
  35. self.assertIsInstance(x, Image)
  36. self.assertEqual([x.width, x.height], [28, 28])
  37. self.assertIsInstance(y, int)
  38. def test_mnist_native_transforms(self):
  39. dataset = MNIST(root=expanduser("~") + "/.avalanche/data/mnist/",
  40. download=True, transform=ToTensor())
  41. x, y = dataset[0]
  42. self.assertIsInstance(x, Tensor)
  43. self.assertEqual(x.shape, (1, 28, 28))
  44. self.assertIsInstance(y, int)
  45. def test_avalanche_dataset_transform(self):
  46. dataset_mnist = MNIST(root=expanduser("~") + "/.avalanche/data/mnist/",
  47. download=True)
  48. x, y = dataset_mnist[0]
  49. dataset = AvalancheDataset(dataset_mnist, transform=ToTensor())
  50. x2, y2, t2 = dataset[0]
  51. self.assertIsInstance(x2, Tensor)
  52. self.assertIsInstance(y2, int)
  53. self.assertIsInstance(t2, int)
  54. self.assertEqual(0, t2)
  55. self.assertTrue(torch.equal(ToTensor()(x), x2))
  56. self.assertEqual(y, y2)
  57. def test_avalanche_dataset_slice(self):
  58. dataset_mnist = MNIST(root=expanduser("~") + "/.avalanche/data/mnist/",
  59. download=True)
  60. x0, y0 = dataset_mnist[0]
  61. x1, y1 = dataset_mnist[1]
  62. dataset = AvalancheDataset(dataset_mnist, transform=ToTensor())
  63. x2, y2, t2 = dataset[:2]
  64. self.assertIsInstance(x2, Tensor)
  65. self.assertIsInstance(y2, Tensor)
  66. self.assertIsInstance(t2, Tensor)
  67. self.assertTrue(torch.equal(ToTensor()(x0), x2[0]))
  68. self.assertTrue(torch.equal(ToTensor()(x1), x2[1]))
  69. self.assertEqual(y0, y2[0].item())
  70. self.assertEqual(y1, y2[1].item())
  71. self.assertEqual(0, t2[0].item())
  72. self.assertEqual(0, t2[1].item())
  73. def test_avalanche_dataset_indexing(self):
  74. dataset_mnist = MNIST(root=expanduser("~") + "/.avalanche/data/mnist/",
  75. download=True)
  76. x0, y0 = dataset_mnist[0]
  77. x1, y1 = dataset_mnist[5]
  78. dataset = AvalancheDataset(dataset_mnist, transform=ToTensor())
  79. x2, y2, t2 = dataset[0, 5]
  80. self.assertIsInstance(x2, Tensor)
  81. self.assertIsInstance(y2, Tensor)
  82. self.assertIsInstance(t2, Tensor)
  83. self.assertTrue(torch.equal(ToTensor()(x0), x2[0]))
  84. self.assertTrue(torch.equal(ToTensor()(x1), x2[1]))
  85. self.assertEqual(y0, y2[0].item())
  86. self.assertEqual(y1, y2[1].item())
  87. self.assertEqual(0, t2[0].item())
  88. self.assertEqual(0, t2[1].item())
  89. def test_avalanche_dataset_composition(self):
  90. dataset_mnist = MNIST(root=expanduser("~") + "/.avalanche/data/mnist/",
  91. download=True, transform=RandomCrop(16))
  92. x, y = dataset_mnist[0]
  93. self.assertIsInstance(x, Image)
  94. self.assertEqual([x.width, x.height], [16, 16])
  95. self.assertIsInstance(y, int)
  96. dataset = AvalancheDataset(
  97. dataset_mnist, transform=ToTensor(),
  98. target_transform=lambda target: -1)
  99. x2, y2, t2 = dataset[0]
  100. self.assertIsInstance(x2, Tensor)
  101. self.assertEqual(x2.shape, (1, 16, 16))
  102. self.assertIsInstance(y2, int)
  103. self.assertEqual(y2, -1)
  104. self.assertIsInstance(t2, int)
  105. self.assertEqual(0, t2)
  106. def test_avalanche_dataset_add(self):
  107. dataset_mnist = MNIST(root=expanduser("~") + "/.avalanche/data/mnist/",
  108. download=True, transform=CenterCrop(16))
  109. dataset1 = AvalancheDataset(
  110. dataset_mnist, transform=ToTensor(),
  111. target_transform=lambda target: -1)
  112. dataset2 = AvalancheDataset(
  113. dataset_mnist, target_transform=lambda target: -2,
  114. task_labels=ConstantSequence(2, len(dataset_mnist)))
  115. dataset3 = dataset1 + dataset2
  116. self.assertEqual(len(dataset_mnist)*2, len(dataset3))
  117. x1, y1, t1 = dataset1[0]
  118. x2, y2, t2 = dataset2[0]
  119. x3, y3, t3 = dataset3[0]
  120. x3_2, y3_2, t3_2 = dataset3[len(dataset_mnist)]
  121. self.assertIsInstance(x1, Tensor)
  122. self.assertEqual(x1.shape, (1, 16, 16))
  123. self.assertEqual(-1, y1)
  124. self.assertEqual(0, t1)
  125. self.assertIsInstance(x2, PIL.Image.Image)
  126. self.assertEqual(x2.size, (16, 16))
  127. self.assertEqual(-2, y2)
  128. self.assertEqual(2, t2)
  129. self.assertEqual((y1, t1), (y3, t3))
  130. self.assertEqual(16 * 16, torch.sum(torch.eq(x1, x3)).item())
  131. self.assertEqual((y2, t2), (y3_2, t3_2))
  132. self.assertTrue(pil_images_equal(x2, x3_2))
  133. def test_avalanche_dataset_radd(self):
  134. dataset_mnist = MNIST(
  135. expanduser("~") + "/.avalanche/data/mnist/",
  136. download=True,
  137. transform=CenterCrop(16))
  138. dataset1 = AvalancheDataset(
  139. dataset_mnist, transform=ToTensor(),
  140. target_transform=lambda target: -1)
  141. dataset2 = dataset_mnist + dataset1
  142. self.assertIsInstance(dataset2, AvalancheDataset)
  143. self.assertEqual(len(dataset_mnist) * 2, len(dataset2))
  144. dataset3 = dataset_mnist + dataset1 + dataset_mnist
  145. self.assertIsInstance(dataset3, AvalancheDataset)
  146. self.assertEqual(len(dataset_mnist) * 3, len(dataset3))
  147. dataset4 = dataset_mnist + dataset_mnist + dataset1
  148. self.assertIsInstance(dataset4, AvalancheDataset)
  149. self.assertEqual(len(dataset_mnist) * 3, len(dataset4))
  150. def test_dataset_add_monkey_patch_vanilla_behaviour(self):
  151. dataset_mnist = MNIST(
  152. expanduser("~") + "/.avalanche/data/mnist/",
  153. download=True,
  154. transform=CenterCrop(16))
  155. dataset_mnist2 = MNIST(
  156. expanduser("~") + "/.avalanche/data/mnist/",
  157. download=True,
  158. transform=CenterCrop(16))
  159. dataset = dataset_mnist + dataset_mnist2
  160. self.assertIsInstance(dataset, ConcatDataset)
  161. self.assertEqual(len(dataset_mnist) * 2, len(dataset))
  162. def test_avalanche_dataset_uniform_task_labels(self):
  163. dataset_mnist = MNIST(root=expanduser("~") + "/.avalanche/data/mnist/",
  164. download=True)
  165. x, y = dataset_mnist[0]
  166. dataset = AvalancheDataset(dataset_mnist, transform=ToTensor(),
  167. task_labels=[1] * len(dataset_mnist))
  168. x2, y2, t2 = dataset[0]
  169. self.assertIsInstance(x2, Tensor)
  170. self.assertIsInstance(y2, int)
  171. self.assertIsInstance(t2, int)
  172. self.assertEqual(1, t2)
  173. self.assertTrue(torch.equal(ToTensor()(x), x2))
  174. self.assertEqual(y, y2)
  175. self.assertListEqual([1] * len(dataset_mnist),
  176. list(dataset.targets_task_labels))
  177. subset_task1 = dataset.task_set[1]
  178. self.assertIsInstance(subset_task1, AvalancheDataset)
  179. self.assertEqual(len(dataset), len(subset_task1))
  180. with self.assertRaises(KeyError):
  181. subset_task0 = dataset.task_set[0]
  182. def test_avalanche_dataset_tensor_task_labels(self):
  183. x = torch.rand(32, 10)
  184. y = torch.rand(32, 10)
  185. t = torch.ones(32) # Single task
  186. dataset = AvalancheTensorDataset(x, y, targets=1, task_labels=t)
  187. x2, y2, t2 = dataset[:]
  188. self.assertIsInstance(x2, Tensor)
  189. self.assertIsInstance(y2, Tensor)
  190. self.assertIsInstance(t2, Tensor)
  191. self.assertTrue(torch.equal(x, x2))
  192. self.assertTrue(torch.equal(y, y2))
  193. self.assertTrue(torch.equal(t.to(int), t2))
  194. self.assertListEqual([1] * 32,
  195. list(dataset.targets_task_labels))
  196. # Regression test for #654
  197. self.assertEqual(1, len(dataset.task_set))
  198. subset_task1 = dataset.task_set[1]
  199. self.assertIsInstance(subset_task1, AvalancheDataset)
  200. self.assertEqual(len(dataset), len(subset_task1))
  201. with self.assertRaises(KeyError):
  202. subset_task0 = dataset.task_set[0]
  203. with self.assertRaises(KeyError):
  204. subset_task0 = dataset.task_set[2]
  205. # Check single instance types
  206. x2, y2, t2 = dataset[0]
  207. self.assertIsInstance(x2, Tensor)
  208. self.assertIsInstance(y2, Tensor)
  209. self.assertIsInstance(t2, int)
  210. def test_avalanche_dataset_uniform_task_labels_simple_def(self):
  211. dataset_mnist = MNIST(root=expanduser("~") + "/.avalanche/data/mnist/",
  212. download=True)
  213. dataset = AvalancheDataset(dataset_mnist, transform=ToTensor(),
  214. task_labels=1)
  215. _, _, t2 = dataset[0]
  216. self.assertIsInstance(t2, int)
  217. self.assertEqual(1, t2)
  218. self.assertListEqual([1] * len(dataset_mnist),
  219. list(dataset.targets_task_labels))
  220. subset_task1 = dataset.task_set[1]
  221. self.assertIsInstance(subset_task1, AvalancheDataset)
  222. self.assertEqual(len(dataset), len(subset_task1))
  223. with self.assertRaises(KeyError):
  224. subset_task0 = dataset.task_set[0]
  225. def test_avalanche_dataset_mixed_task_labels(self):
  226. dataset_mnist = MNIST(root=expanduser("~") + "/.avalanche/data/mnist/",
  227. download=True)
  228. x, y = dataset_mnist[0]
  229. random_task_labels = [random.randint(0, 10)
  230. for _ in range(len(dataset_mnist))]
  231. dataset = AvalancheDataset(dataset_mnist, transform=ToTensor(),
  232. task_labels=random_task_labels)
  233. x2, y2, t2 = dataset[0]
  234. self.assertIsInstance(x2, Tensor)
  235. self.assertIsInstance(y2, int)
  236. self.assertIsInstance(t2, int)
  237. self.assertEqual(random_task_labels[0], t2)
  238. self.assertTrue(torch.equal(ToTensor()(x), x2))
  239. self.assertEqual(y, y2)
  240. self.assertListEqual(random_task_labels,
  241. list(dataset.targets_task_labels))
  242. u_labels, counts = np.unique(random_task_labels, return_counts=True)
  243. for i, task_label in enumerate(u_labels.tolist()):
  244. subset_task = dataset.task_set[task_label]
  245. self.assertIsInstance(subset_task, AvalancheDataset)
  246. self.assertEqual(int(counts[i]), len(subset_task))
  247. unique_task_labels = list(subset_task.targets_task_labels)
  248. self.assertListEqual([task_label] * int(counts[i]),
  249. unique_task_labels)
  250. with self.assertRaises(KeyError):
  251. subset_task11 = dataset.task_set[11]
  252. def test_avalanche_tensor_dataset_task_labels_train(self):
  253. tr_ds = [AvalancheTensorDataset(
  254. torch.randn(10, 4),
  255. torch.randint(0, 3, (10,)),
  256. dataset_type=AvalancheDatasetType.CLASSIFICATION,
  257. task_labels=torch.randint(0, 5, (10,)).tolist()) for i in range(3)]
  258. ts_ds = [AvalancheTensorDataset(
  259. torch.randn(10, 4), torch.randint(0, 3, (10,)),
  260. dataset_type=AvalancheDatasetType.CLASSIFICATION,
  261. task_labels=torch.randint(0, 5, (10,)).tolist()) for i in range(3)]
  262. benchmark = dataset_benchmark(train_datasets=tr_ds, test_datasets=ts_ds)
  263. model = SimpleMLP(input_size=4, num_classes=3)
  264. cl_strategy = Naive(model, SGD(model.parameters(), lr=0.001,
  265. momentum=0.9),
  266. CrossEntropyLoss(), train_mb_size=5,
  267. train_epochs=1, eval_mb_size=5,
  268. device='cpu', evaluator=None)
  269. exp = []
  270. for i, experience in enumerate(benchmark.train_stream):
  271. exp.append(i)
  272. cl_strategy.train(experience)
  273. self.assertEqual(len(exp), 3)
  274. def test_avalanche_dataset_task_labels_inheritance(self):
  275. dataset_mnist = MNIST(root=expanduser("~") + "/.avalanche/data/mnist/",
  276. download=True)
  277. random_task_labels = [random.randint(0, 10)
  278. for _ in range(len(dataset_mnist))]
  279. dataset_orig = AvalancheDataset(dataset_mnist, transform=ToTensor(),
  280. task_labels=random_task_labels)
  281. dataset_child = AvalancheDataset(dataset_orig)
  282. x2, y2, t2 = dataset_orig[0]
  283. x3, y3, t3 = dataset_child[0]
  284. self.assertIsInstance(t2, int)
  285. self.assertEqual(random_task_labels[0], t2)
  286. self.assertIsInstance(t3, int)
  287. self.assertEqual(random_task_labels[0], t3)
  288. self.assertListEqual(random_task_labels,
  289. list(dataset_orig.targets_task_labels))
  290. self.assertListEqual(random_task_labels,
  291. list(dataset_child.targets_task_labels))
  292. def test_avalanche_dataset_tensor_dataset_input(self):
  293. train_x = torch.rand(500, 3, 28, 28)
  294. train_y = torch.zeros(500)
  295. test_x = torch.rand(200, 3, 28, 28)
  296. test_y = torch.ones(200)
  297. train = TensorDataset(train_x, train_y)
  298. test = TensorDataset(test_x, test_y)
  299. train_dataset = AvalancheDataset(train)
  300. test_dataset = AvalancheDataset(test)
  301. self.assertEqual(500, len(train_dataset))
  302. self.assertEqual(200, len(test_dataset))
  303. x, y, t = train_dataset[0]
  304. self.assertIsInstance(x, Tensor)
  305. self.assertEqual(0, y)
  306. self.assertEqual(0, t)
  307. x2, y2, t2 = test_dataset[0]
  308. self.assertIsInstance(x2, Tensor)
  309. self.assertEqual(1, y2)
  310. self.assertEqual(0, t2)
  311. def test_avalanche_dataset_multiple_outputs_and_float_y(self):
  312. train_x = torch.rand(500, 3, 28, 28)
  313. train_y = torch.zeros(500)
  314. train_z = torch.ones(500)
  315. test_x = torch.rand(200, 3, 28, 28)
  316. test_y = torch.ones(200)
  317. test_z = torch.full((200,), 5)
  318. train = TensorDataset(train_x, train_y, train_z)
  319. test = TensorDataset(test_x, test_y, test_z)
  320. train_dataset = AvalancheDataset(train)
  321. test_dataset = AvalancheDataset(test)
  322. self.assertEqual(500, len(train_dataset))
  323. self.assertEqual(200, len(test_dataset))
  324. x, y, z, t = train_dataset[0]
  325. self.assertIsInstance(x, Tensor)
  326. self.assertEqual(0, y)
  327. self.assertEqual(1, z)
  328. self.assertEqual(0, t)
  329. x2, y2, z2, t2 = test_dataset[0]
  330. self.assertIsInstance(x2, Tensor)
  331. self.assertEqual(1, y2)
  332. self.assertEqual(5, z2)
  333. self.assertEqual(0, t2)
  334. def test_avalanche_dataset_from_pytorch_subset(self):
  335. tensor_x = torch.rand(500, 3, 28, 28)
  336. tensor_y = torch.randint(0, 100, (500,))
  337. whole_dataset = TensorDataset(tensor_x, tensor_y)
  338. train = Subset(whole_dataset, indices=list(range(400)))
  339. test = Subset(whole_dataset, indices=list(range(400, 500)))
  340. train_dataset = AvalancheDataset(train)
  341. test_dataset = AvalancheDataset(test)
  342. self.assertEqual(400, len(train_dataset))
  343. self.assertEqual(100, len(test_dataset))
  344. x, y, t = train_dataset[0]
  345. self.assertIsInstance(x, Tensor)
  346. self.assertTrue(torch.equal(tensor_x[0], x))
  347. self.assertTrue(torch.equal(tensor_y[0], y))
  348. self.assertEqual(0, t)
  349. self.assertTrue(torch.equal(torch.as_tensor(train_dataset.targets),
  350. tensor_y[:400]))
  351. x2, y2, t2 = test_dataset[0]
  352. self.assertIsInstance(x2, Tensor)
  353. self.assertTrue(torch.equal(tensor_x[400], x2))
  354. self.assertTrue(torch.equal(tensor_y[400], y2))
  355. self.assertEqual(0, t2)
  356. self.assertTrue(torch.equal(torch.as_tensor(test_dataset.targets),
  357. tensor_y[400:]))
  358. def test_avalanche_dataset_from_pytorch_concat_dataset(self):
  359. tensor_x = torch.rand(500, 3, 28, 28)
  360. tensor_x2 = torch.rand(300, 3, 28, 28)
  361. tensor_y = torch.randint(0, 100, (500,))
  362. tensor_y2 = torch.randint(0, 100, (300,))
  363. dataset1 = TensorDataset(tensor_x, tensor_y)
  364. dataset2 = TensorDataset(tensor_x2, tensor_y2)
  365. concat_dataset = ConcatDataset((dataset1, dataset2))
  366. av_dataset = AvalancheDataset(concat_dataset)
  367. self.assertEqual(500, len(dataset1))
  368. self.assertEqual(300, len(dataset2))
  369. x, y, t = av_dataset[0]
  370. x2, y2, t2 = av_dataset[500]
  371. self.assertIsInstance(x, Tensor)
  372. self.assertTrue(torch.equal(tensor_x[0], x))
  373. self.assertTrue(torch.equal(tensor_y[0], y))
  374. self.assertEqual(0, t)
  375. self.assertIsInstance(x2, Tensor)
  376. self.assertTrue(torch.equal(tensor_x2[0], x2))
  377. self.assertTrue(torch.equal(tensor_y2[0], y2))
  378. self.assertEqual(0, t2)
  379. self.assertTrue(torch.equal(torch.as_tensor(av_dataset.targets),
  380. torch.cat((tensor_y, tensor_y2))))
  381. def test_avalanche_dataset_from_chained_pytorch_concat_dataset(self):
  382. tensor_x = torch.rand(500, 3, 28, 28)
  383. tensor_x2 = torch.rand(300, 3, 28, 28)
  384. tensor_x3 = torch.rand(200, 3, 28, 28)
  385. tensor_y = torch.randint(0, 100, (500,))
  386. tensor_y2 = torch.randint(0, 100, (300,))
  387. tensor_y3 = torch.randint(0, 100, (200,))
  388. dataset1 = TensorDataset(tensor_x, tensor_y)
  389. dataset2 = TensorDataset(tensor_x2, tensor_y2)
  390. dataset3 = TensorDataset(tensor_x3, tensor_y3)
  391. concat_dataset = ConcatDataset((dataset1, dataset2))
  392. concat_dataset2 = ConcatDataset((concat_dataset, dataset3))
  393. av_dataset = AvalancheDataset(concat_dataset2)
  394. self.assertEqual(500, len(dataset1))
  395. self.assertEqual(300, len(dataset2))
  396. x, y, t = av_dataset[0]
  397. x2, y2, t2 = av_dataset[500]
  398. x3, y3, t3 = av_dataset[800]
  399. self.assertIsInstance(x, Tensor)
  400. self.assertTrue(torch.equal(tensor_x[0], x))
  401. self.assertTrue(torch.equal(tensor_y[0], y))
  402. self.assertEqual(0, t)
  403. self.assertIsInstance(x2, Tensor)
  404. self.assertTrue(torch.equal(tensor_x2[0], x2))
  405. self.assertTrue(torch.equal(tensor_y2[0], y2))
  406. self.assertEqual(0, t2)
  407. self.assertIsInstance(x3, Tensor)
  408. self.assertTrue(torch.equal(tensor_x3[0], x3))
  409. self.assertTrue(torch.equal(tensor_y3[0], y3))
  410. self.assertEqual(0, t3)
  411. self.assertTrue(torch.equal(
  412. torch.as_tensor(av_dataset.targets),
  413. torch.cat((tensor_y, tensor_y2, tensor_y3))))
  414. def test_avalanche_dataset_from_chained_pytorch_subsets(self):
  415. tensor_x = torch.rand(500, 3, 28, 28)
  416. tensor_y = torch.randint(0, 100, (500,))
  417. whole_dataset = TensorDataset(tensor_x, tensor_y)
  418. subset1 = Subset(whole_dataset, indices=list(range(400, 500)))
  419. subset2 = Subset(subset1, indices=[5, 7, 0])
  420. dataset = AvalancheDataset(subset2)
  421. self.assertEqual(3, len(dataset))
  422. x, y, t = dataset[0]
  423. self.assertIsInstance(x, Tensor)
  424. self.assertTrue(torch.equal(tensor_x[405], x))
  425. self.assertTrue(torch.equal(tensor_y[405], y))
  426. self.assertEqual(0, t)
  427. self.assertTrue(
  428. torch.equal(
  429. torch.as_tensor(dataset.targets),
  430. torch.as_tensor([tensor_y[405], tensor_y[407], tensor_y[400]])
  431. )
  432. )
  433. def test_avalanche_dataset_from_chained_pytorch_concat_subset_dataset(self):
  434. tensor_x = torch.rand(200, 3, 28, 28)
  435. tensor_x2 = torch.rand(100, 3, 28, 28)
  436. tensor_y = torch.randint(0, 100, (200,))
  437. tensor_y2 = torch.randint(0, 100, (100,))
  438. dataset1 = TensorDataset(tensor_x, tensor_y)
  439. dataset2 = TensorDataset(tensor_x2, tensor_y2)
  440. indices = [random.randint(0, 299) for _ in range(1000)]
  441. concat_dataset = ConcatDataset((dataset1, dataset2))
  442. subset = Subset(concat_dataset, indices)
  443. av_dataset = AvalancheDataset(subset)
  444. self.assertEqual(200, len(dataset1))
  445. self.assertEqual(100, len(dataset2))
  446. self.assertEqual(1000, len(av_dataset))
  447. for idx in range(1000):
  448. orig_idx = indices[idx]
  449. if orig_idx < 200:
  450. expected_x, expected_y = dataset1[orig_idx]
  451. else:
  452. expected_x, expected_y = dataset2[orig_idx-200]
  453. x, y, t = av_dataset[idx]
  454. self.assertIsInstance(x, Tensor)
  455. self.assertTrue(torch.equal(expected_x, x))
  456. self.assertTrue(torch.equal(expected_y, y))
  457. self.assertEqual(0, t)
  458. self.assertEqual(int(expected_y), int(av_dataset.targets[idx]))
  459. def test_avalanche_dataset_from_chained_pytorch_datasets(self):
  460. tensor_x = torch.rand(200, 3, 28, 28)
  461. tensor_x2 = torch.rand(100, 3, 28, 28)
  462. tensor_y = torch.randint(0, 100, (200,))
  463. tensor_y2 = torch.randint(0, 100, (100,))
  464. dataset1 = TensorDataset(tensor_x, tensor_y)
  465. dataset1_sub = Subset(dataset1, range(199, -1, -1))
  466. dataset2 = TensorDataset(tensor_x2, tensor_y2)
  467. indices = [random.randint(0, 299) for _ in range(1000)]
  468. concat_dataset = ConcatDataset((dataset1_sub, dataset2))
  469. subset = Subset(concat_dataset, indices)
  470. av_dataset = AvalancheDataset(subset)
  471. self.assertEqual(200, len(dataset1_sub))
  472. self.assertEqual(100, len(dataset2))
  473. self.assertEqual(1000, len(av_dataset))
  474. for idx in range(1000):
  475. orig_idx = indices[idx]
  476. if orig_idx < 200:
  477. orig_idx = range(199, -1, -1)[orig_idx]
  478. expected_x, expected_y = dataset1[orig_idx]
  479. else:
  480. expected_x, expected_y = dataset2[orig_idx-200]
  481. x, y, t = av_dataset[idx]
  482. self.assertIsInstance(x, Tensor)
  483. self.assertTrue(torch.equal(expected_x, x))
  484. self.assertTrue(torch.equal(expected_y, y))
  485. self.assertEqual(0, t)
  486. self.assertEqual(int(expected_y), int(av_dataset.targets[idx]))
  487. def test_avalanche_dataset_from_chained_pytorch_datasets_task_labels(self):
  488. tensor_x = torch.rand(200, 3, 28, 28)
  489. tensor_x2 = torch.rand(100, 3, 28, 28)
  490. tensor_y = torch.randint(0, 100, (200,))
  491. tensor_y2 = torch.randint(0, 100, (100,))
  492. tensor_t = torch.randint(0, 100, (200,))
  493. tensor_t2 = torch.randint(0, 100, (100,))
  494. dataset1 = AvalancheTensorDataset(
  495. tensor_x, tensor_y, task_labels=tensor_t)
  496. dataset1_sub = Subset(dataset1, range(199, -1, -1))
  497. dataset2 = AvalancheDataset(
  498. TensorDataset(tensor_x2, tensor_y2), task_labels=tensor_t2)
  499. indices = [random.randint(0, 299) for _ in range(1000)]
  500. concat_dataset = ConcatDataset((dataset1_sub, dataset2))
  501. subset = Subset(concat_dataset, indices)
  502. av_dataset = AvalancheDataset(subset)
  503. self.assertEqual(200, len(dataset1_sub))
  504. self.assertEqual(100, len(dataset2))
  505. self.assertEqual(1000, len(av_dataset))
  506. for idx in range(1000):
  507. orig_idx = indices[idx]
  508. if orig_idx < 200:
  509. orig_idx = range(199, -1, -1)[orig_idx]
  510. expected_x = tensor_x[orig_idx]
  511. expected_y = tensor_y[orig_idx]
  512. expected_t = tensor_t[orig_idx]
  513. else:
  514. orig_idx -= 200
  515. expected_x = tensor_x2[orig_idx]
  516. expected_y = tensor_y2[orig_idx]
  517. expected_t = tensor_t2[orig_idx]
  518. x, y, t = av_dataset[idx]
  519. self.assertIsInstance(x, Tensor)
  520. self.assertTrue(torch.equal(expected_x, x))
  521. self.assertTrue(torch.equal(expected_y, y))
  522. self.assertIsInstance(t, int)
  523. self.assertEqual(int(expected_t), int(t))
  524. self.assertEqual(int(expected_y), int(av_dataset.targets[idx]))
  525. def test_avalanche_dataset_collate_fn(self):
  526. tensor_x = torch.rand(500, 3, 28, 28)
  527. tensor_y = torch.randint(0, 100, (500,))
  528. tensor_z = torch.randint(0, 100, (500,))
  529. def my_collate_fn(patterns):
  530. x_values = torch.stack([pat[0] for pat in patterns], 0)
  531. y_values = torch.tensor([pat[1] for pat in patterns]) + 1
  532. z_values = torch.tensor([-1 for _ in patterns])
  533. t_values = torch.tensor([pat[3] for pat in patterns])
  534. return x_values, y_values, z_values, t_values
  535. whole_dataset = TensorDataset(tensor_x, tensor_y, tensor_z)
  536. dataset = AvalancheDataset(whole_dataset, collate_fn=my_collate_fn)
  537. x, y, z, t = dataset[0]
  538. self.assertIsInstance(x, Tensor)
  539. self.assertTrue(torch.equal(tensor_x[0], x))
  540. self.assertTrue(torch.equal(tensor_y[0], y))
  541. self.assertEqual(0, t)
  542. x2, y2, z2, t2 = dataset[0:5]
  543. self.assertIsInstance(x2, Tensor)
  544. self.assertTrue(torch.equal(tensor_x[0:5], x2))
  545. self.assertTrue(torch.equal(tensor_y[0:5]+1, y2))
  546. self.assertTrue(torch.equal(torch.full((5,), -1, dtype=torch.long), z2))
  547. self.assertTrue(torch.equal(torch.zeros(5, dtype=torch.long), t2))
  548. inherited = AvalancheDataset(dataset)
  549. x3, y3, z3, t3 = inherited[0:5]
  550. self.assertIsInstance(x3, Tensor)
  551. self.assertTrue(torch.equal(tensor_x[0:5], x3))
  552. self.assertTrue(torch.equal(tensor_y[0:5] + 1, y3))
  553. self.assertTrue(torch.equal(torch.full((5,), -1, dtype=torch.long), z3))
  554. self.assertTrue(torch.equal(torch.zeros(5, dtype=torch.long), t3))
  555. with self.assertRaises(ValueError):
  556. # Can't define a custom collate when dataset_type != UNDEFINED
  557. bad_definition = AvalancheDataset(
  558. dataset, dataset_type=AvalancheDatasetType.CLASSIFICATION,
  559. collate_fn=my_collate_fn)
  560. def test_avalanche_dataset_collate_fn_inheritance(self):
  561. tensor_x = torch.rand(200, 3, 28, 28)
  562. tensor_y = torch.randint(0, 100, (200,))
  563. tensor_z = torch.randint(0, 100, (200,))
  564. def my_collate_fn(patterns):
  565. x_values = torch.stack([pat[0] for pat in patterns], 0)
  566. y_values = torch.tensor([pat[1] for pat in patterns]) + 1
  567. z_values = torch.tensor([-1 for _ in patterns])
  568. t_values = torch.tensor([pat[3] for pat in patterns])
  569. return x_values, y_values, z_values, t_values
  570. def my_collate_fn2(patterns):
  571. x_values = torch.stack([pat[0] for pat in patterns], 0)
  572. y_values = torch.tensor([pat[1] for pat in patterns]) + 2
  573. z_values = torch.tensor([-2 for _ in patterns])
  574. t_values = torch.tensor([pat[3] for pat in patterns])
  575. return x_values, y_values, z_values, t_values
  576. whole_dataset = TensorDataset(tensor_x, tensor_y, tensor_z)
  577. dataset = AvalancheDataset(whole_dataset, collate_fn=my_collate_fn)
  578. inherited = AvalancheDataset(dataset, collate_fn=my_collate_fn2) # Ok
  579. x, y, z, t = inherited[0:5]
  580. self.assertIsInstance(x, Tensor)
  581. self.assertTrue(torch.equal(tensor_x[0:5], x))
  582. self.assertTrue(torch.equal(tensor_y[0:5] + 2, y))
  583. self.assertTrue(torch.equal(torch.full((5,), -2, dtype=torch.long), z))
  584. self.assertTrue(torch.equal(torch.zeros(5, dtype=torch.long), t))
  585. classification_dataset = AvalancheDataset(
  586. whole_dataset, dataset_type=AvalancheDatasetType.CLASSIFICATION)
  587. with self.assertRaises(ValueError):
  588. bad_inherited = AvalancheDataset(
  589. classification_dataset, collate_fn=my_collate_fn)
  590. ok_inherited_classification = AvalancheDataset(classification_dataset)
  591. self.assertEqual(AvalancheDatasetType.CLASSIFICATION,
  592. ok_inherited_classification.dataset_type)
  593. def test_avalanche_concat_dataset_collate_fn_inheritance(self):
  594. tensor_x = torch.rand(200, 3, 28, 28)
  595. tensor_y = torch.randint(0, 100, (200,))
  596. tensor_z = torch.randint(0, 100, (200,))
  597. tensor_x2 = torch.rand(200, 3, 28, 28)
  598. tensor_y2 = torch.randint(0, 100, (200,))
  599. tensor_z2 = torch.randint(0, 100, (200,))
  600. def my_collate_fn(patterns):
  601. x_values = torch.stack([pat[0] for pat in patterns], 0)
  602. y_values = torch.tensor([pat[1] for pat in patterns]) + 1
  603. z_values = torch.tensor([-1 for _ in patterns])
  604. t_values = torch.tensor([pat[3] for pat in patterns])
  605. return x_values, y_values, z_values, t_values
  606. def my_collate_fn2(patterns):
  607. x_values = torch.stack([pat[0] for pat in patterns], 0)
  608. y_values = torch.tensor([pat[1] for pat in patterns]) + 2
  609. z_values = torch.tensor([-2 for _ in patterns])
  610. t_values = torch.tensor([pat[3] for pat in patterns])
  611. return x_values, y_values, z_values, t_values
  612. dataset1 = TensorDataset(tensor_x, tensor_y, tensor_z)
  613. dataset2 = AvalancheTensorDataset(tensor_x2, tensor_y2, tensor_z2,
  614. collate_fn=my_collate_fn)
  615. concat = AvalancheConcatDataset([dataset1, dataset2],
  616. collate_fn=my_collate_fn2) # Ok
  617. x, y, z, t = dataset2[0:5]
  618. self.assertIsInstance(x, Tensor)
  619. self.assertTrue(torch.equal(tensor_x2[0:5], x))
  620. self.assertTrue(torch.equal(tensor_y2[0:5] + 1, y))
  621. self.assertTrue(torch.equal(torch.full((5,), -1, dtype=torch.long), z))
  622. self.assertTrue(torch.equal(torch.zeros(5, dtype=torch.long), t))
  623. x2, y2, z2, t2 = concat[0:5]
  624. self.assertIsInstance(x2, Tensor)
  625. self.assertTrue(torch.equal(tensor_x[0:5], x2))
  626. self.assertTrue(torch.equal(tensor_y[0:5] + 2, y2))
  627. self.assertTrue(torch.equal(torch.full((5,), -2, dtype=torch.long), z2))
  628. self.assertTrue(torch.equal(torch.zeros(5, dtype=torch.long), t2))
  629. dataset1_classification = AvalancheTensorDataset(
  630. tensor_x, tensor_y, tensor_z,
  631. dataset_type=AvalancheDatasetType.CLASSIFICATION)
  632. dataset2_segmentation = AvalancheDataset(
  633. dataset2, dataset_type=AvalancheDatasetType.SEGMENTATION)
  634. with self.assertRaises(ValueError):
  635. bad_concat_types = dataset1_classification + dataset2_segmentation
  636. with self.assertRaises(ValueError):
  637. bad_concat_collate = AvalancheConcatDataset(
  638. [dataset1, dataset2_segmentation], collate_fn=my_collate_fn)
  639. ok_concat_classification = dataset1_classification + dataset2
  640. self.assertEqual(AvalancheDatasetType.CLASSIFICATION,
  641. ok_concat_classification.dataset_type)
  642. ok_concat_classification2 = dataset2 + dataset1_classification
  643. self.assertEqual(AvalancheDatasetType.CLASSIFICATION,
  644. ok_concat_classification2.dataset_type)
  645. def test_avalanche_concat_dataset_recursion(self):
  646. def gen_random_tensors(n):
  647. return torch.rand(n, 3, 28, 28),\
  648. torch.randint(0, 100, (n,)),\
  649. torch.randint(0, 100, (n,))
  650. tensor_x, tensor_y, tensor_z = \
  651. gen_random_tensors(200)
  652. tensor_x2, tensor_y2, tensor_z2 = \
  653. gen_random_tensors(200)
  654. tensor_x3, tensor_y3, tensor_z3 = \
  655. gen_random_tensors(200)
  656. tensor_x4, tensor_y4, tensor_z4 = \
  657. gen_random_tensors(200)
  658. tensor_x5, tensor_y5, tensor_z5 = \
  659. gen_random_tensors(200)
  660. tensor_x6, tensor_y6, tensor_z6 = \
  661. gen_random_tensors(200)
  662. tensor_x7, tensor_y7, tensor_z7 = \
  663. gen_random_tensors(200)
  664. dataset1 = TensorDataset(tensor_x, tensor_y, tensor_z)
  665. dataset2 = AvalancheTensorDataset(tensor_x2, tensor_y2, tensor_z2,
  666. task_labels=1)
  667. dataset3 = AvalancheTensorDataset(tensor_x3, tensor_y3, tensor_z3,
  668. task_labels=2)
  669. dataset4 = AvalancheTensorDataset(tensor_x4, tensor_y4, tensor_z4,
  670. task_labels=3)
  671. dataset5 = AvalancheTensorDataset(tensor_x5, tensor_y5, tensor_z5,
  672. task_labels=4)
  673. dataset6 = AvalancheTensorDataset(tensor_x6, tensor_y6, tensor_z6)
  674. dataset7 = AvalancheTensorDataset(tensor_x7, tensor_y7, tensor_z7)
  675. # This will test recursion on both PyTorch ConcatDataset and
  676. # AvalancheConcatDataset
  677. concat = ConcatDataset([dataset1, dataset2])
  678. # Beware of the explicit task_labels=5 that *must* override the
  679. # task labels set in dataset4 and dataset5
  680. def transform_target_to_constant(ignored_target_value):
  681. return 101
  682. def transform_target_to_constant2(ignored_target_value):
  683. return 102
  684. concat2 = AvalancheConcatDataset(
  685. [dataset4, dataset5], task_labels=5,
  686. target_transform=transform_target_to_constant)
  687. concat3 = AvalancheConcatDataset(
  688. [dataset6, dataset7],
  689. target_transform=transform_target_to_constant2).freeze_transforms()
  690. concat_uut = AvalancheConcatDataset(
  691. [concat, dataset3, concat2, concat3])
  692. self.assertEqual(400, len(concat))
  693. self.assertEqual(400, len(concat2))
  694. self.assertEqual(400, len(concat3))
  695. self.assertEqual(1400, len(concat_uut))
  696. x, y, z, t = concat_uut[0]
  697. x2, y2, z2, t2 = concat_uut[200]
  698. x3, y3, z3, t3 = concat_uut[400]
  699. x4, y4, z4, t4 = concat_uut[600]
  700. x5, y5, z5, t5 = concat_uut[800]
  701. x6, y6, z6, t6 = concat_uut[1000]
  702. x7, y7, z7, t7 = concat_uut[1200]
  703. self.assertTrue(torch.equal(x, tensor_x[0]))
  704. self.assertTrue(torch.equal(y, tensor_y[0]))
  705. self.assertTrue(torch.equal(z, tensor_z[0]))
  706. self.assertEqual(0, t)
  707. self.assertTrue(torch.equal(x2, tensor_x2[0]))
  708. self.assertTrue(torch.equal(y2, tensor_y2[0]))
  709. self.assertTrue(torch.equal(z2, tensor_z2[0]))
  710. self.assertEqual(1, t2)
  711. self.assertTrue(torch.equal(x3, tensor_x3[0]))
  712. self.assertTrue(torch.equal(y3, tensor_y3[0]))
  713. self.assertTrue(torch.equal(z3, tensor_z3[0]))
  714. self.assertEqual(2, t3)
  715. self.assertTrue(torch.equal(x4, tensor_x4[0]))
  716. self.assertEqual(101, y4)
  717. self.assertTrue(torch.equal(z4, tensor_z4[0]))
  718. self.assertEqual(5, t4)
  719. self.assertTrue(torch.equal(x5, tensor_x5[0]))
  720. self.assertEqual(101, y5)
  721. self.assertTrue(torch.equal(z5, tensor_z5[0]))
  722. self.assertEqual(5, t5)
  723. self.assertTrue(torch.equal(x6, tensor_x6[0]))
  724. self.assertEqual(102, y6)
  725. self.assertTrue(torch.equal(z6, tensor_z6[0]))
  726. self.assertEqual(0, t6)
  727. self.assertTrue(torch.equal(x7, tensor_x7[0]))
  728. self.assertEqual(102, y7)
  729. self.assertTrue(torch.equal(z7, tensor_z7[0]))
  730. self.assertEqual(0, t7)
  731. def test_avalanche_pytorch_subset_recursion(self):
  732. dataset_mnist = MNIST(root=expanduser("~") + "/.avalanche/data/mnist/",
  733. download=True)
  734. x, y = dataset_mnist[3000]
  735. x2, y2 = dataset_mnist[1010]
  736. subset = Subset(dataset_mnist, indices=[3000, 8, 4, 1010, 12])
  737. dataset = AvalancheSubset(
  738. subset, indices=[0, 3])
  739. self.assertEqual(5, len(subset))
  740. self.assertEqual(2, len(dataset))
  741. x3, y3, t3 = dataset[0]
  742. x4, y4, t4 = dataset[1]
  743. self.assertTrue(pil_images_equal(x, x3))
  744. self.assertEqual(y, y3)
  745. self.assertEqual(0, t3)
  746. self.assertTrue(pil_images_equal(x2, x4))
  747. self.assertEqual(y2, y4)
  748. self.assertEqual(0, t4)
  749. self.assertFalse(pil_images_equal(x, x4))
  750. self.assertFalse(pil_images_equal(x2, x3))
  751. def transform_target_to_constant(ignored_target_value):
  752. return 101
  753. subset = Subset(dataset_mnist, indices=[3000, 8, 4, 1010, 12])
  754. dataset = AvalancheSubset(
  755. subset, indices=[0, 3],
  756. target_transform=transform_target_to_constant,
  757. task_labels=5)
  758. self.assertEqual(5, len(subset))
  759. self.assertEqual(2, len(dataset))
  760. x5, y5, t5 = dataset[0]
  761. x6, y6, t6 = dataset[1]
  762. self.assertTrue(pil_images_equal(x, x5))
  763. self.assertEqual(101, y5)
  764. self.assertEqual(5, t5)
  765. self.assertTrue(pil_images_equal(x2, x6))
  766. self.assertEqual(101, y6)
  767. self.assertEqual(5, t6)
  768. self.assertFalse(pil_images_equal(x, x6))
  769. self.assertFalse(pil_images_equal(x2, x5))
  770. def test_avalanche_pytorch_subset_recursion_no_indices(self):
  771. dataset_mnist = MNIST(
  772. root=expanduser("~") + "/.avalanche/data/mnist/",
  773. download=True)
  774. x, y = dataset_mnist[3000]
  775. x2, y2 = dataset_mnist[8]
  776. subset = Subset(dataset_mnist, indices=[3000, 8, 4, 1010, 12])
  777. dataset = AvalancheSubset(subset)
  778. self.assertEqual(5, len(subset))
  779. self.assertEqual(5, len(dataset))
  780. x3, y3, t3 = dataset[0]
  781. x4, y4, t4 = dataset[1]
  782. self.assertTrue(pil_images_equal(x, x3))
  783. self.assertEqual(y, y3)
  784. self.assertTrue(pil_images_equal(x2, x4))
  785. self.assertEqual(y2, y4)
  786. self.assertFalse(pil_images_equal(x, x4))
  787. self.assertFalse(pil_images_equal(x2, x3))
  788. def test_avalanche_avalanche_subset_recursion_no_indices_transform(self):
  789. dataset_mnist = MNIST(
  790. root=expanduser("~") + "/.avalanche/data/mnist/",
  791. download=True)
  792. x, y = dataset_mnist[3000]
  793. x2, y2 = dataset_mnist[8]
  794. def transform_target_to_constant(ignored_target_value):
  795. return 101
  796. def transform_target_plus_one(target_value):
  797. return target_value+1
  798. subset = AvalancheSubset(dataset_mnist,
  799. indices=[3000, 8, 4, 1010, 12],
  800. transform=ToTensor(),
  801. target_transform=transform_target_to_constant)
  802. dataset = AvalancheSubset(subset,
  803. target_transform=transform_target_plus_one)
  804. self.assertEqual(5, len(subset))
  805. self.assertEqual(5, len(dataset))
  806. x3, y3, t3 = dataset[0]
  807. x4, y4, t4 = dataset[1]
  808. self.assertIsInstance(x3, Tensor)
  809. self.assertIsInstance(x4, Tensor)
  810. self.assertTrue(torch.equal(ToTensor()(x), x3))
  811. self.assertEqual(102, y3)
  812. self.assertTrue(torch.equal(ToTensor()(x2), x4))
  813. self.assertEqual(102, y4)
  814. self.assertFalse(torch.equal(ToTensor()(x), x4))
  815. self.assertFalse(torch.equal(ToTensor()(x2), x3))
  816. def test_avalanche_avalanche_subset_recursion_transform(self):
  817. dataset_mnist = MNIST(
  818. root=expanduser("~") + "/.avalanche/data/mnist/",
  819. download=True)
  820. x, y = dataset_mnist[3000]
  821. x2, y2 = dataset_mnist[1010]
  822. def transform_target_to_constant(ignored_target_value):
  823. return 101
  824. def transform_target_plus_one(target_value):
  825. return target_value+2
  826. subset = AvalancheSubset(dataset_mnist,
  827. indices=[3000, 8, 4, 1010, 12],
  828. target_transform=transform_target_to_constant)
  829. dataset = AvalancheSubset(subset,
  830. indices=[0, 3, 1],
  831. target_transform=transform_target_plus_one)
  832. self.assertEqual(5, len(subset))
  833. self.assertEqual(3, len(dataset))
  834. x3, y3, t3 = dataset[0]
  835. x4, y4, t4 = dataset[1]
  836. self.assertTrue(pil_images_equal(x, x3))
  837. self.assertEqual(103, y3)
  838. self.assertTrue(pil_images_equal(x2, x4))
  839. self.assertEqual(103, y4)
  840. self.assertFalse(pil_images_equal(x, x4))
  841. self.assertFalse(pil_images_equal(x2, x3))
  842. def test_avalanche_avalanche_subset_recursion_frozen_transform(self):
  843. dataset_mnist = MNIST(
  844. root=expanduser("~") + "/.avalanche/data/mnist/",
  845. download=True)
  846. x, y = dataset_mnist[3000]
  847. x2, y2 = dataset_mnist[1010]
  848. def transform_target_to_constant(ignored_target_value):
  849. return 101
  850. def transform_target_plus_two(target_value):
  851. return target_value+2
  852. subset = AvalancheSubset(dataset_mnist,
  853. indices=[3000, 8, 4, 1010, 12],
  854. target_transform=transform_target_to_constant)
  855. subset = subset.freeze_transforms()
  856. dataset = AvalancheSubset(subset,
  857. indices=[0, 3, 1],
  858. target_transform=transform_target_plus_two)
  859. self.assertEqual(5, len(subset))
  860. self.assertEqual(3, len(dataset))
  861. x3, y3, t3 = dataset[0]
  862. x4, y4, t4 = dataset[1]
  863. self.assertTrue(pil_images_equal(x, x3))
  864. self.assertEqual(103, y3)
  865. self.assertTrue(pil_images_equal(x2, x4))
  866. self.assertEqual(103, y4)
  867. self.assertFalse(pil_images_equal(x, x4))
  868. self.assertFalse(pil_images_equal(x2, x3))
  869. dataset = AvalancheSubset(subset,
  870. indices=[0, 3, 1],
  871. target_transform=transform_target_plus_two)
  872. dataset = dataset.replace_transforms(None, None)
  873. x5, y5, t5 = dataset[0]
  874. x6, y6, t6 = dataset[1]
  875. self.assertTrue(pil_images_equal(x, x5))
  876. self.assertEqual(101, y5)
  877. self.assertTrue(pil_images_equal(x2, x6))
  878. self.assertEqual(101, y6)
  879. self.assertFalse(pil_images_equal(x, x6))
  880. self.assertFalse(pil_images_equal(x2, x5))
  881. def test_avalanche_avalanche_subset_recursion_modified_transforms(self):
  882. dataset_mnist = MNIST(
  883. root=expanduser("~") + "/.avalanche/data/mnist/",
  884. download=True)
  885. x, y = dataset_mnist[3000]
  886. x2, y2 = dataset_mnist[1010]
  887. def transform_target_to_constant(ignored_target_value):
  888. return 101
  889. def transform_target_to_constant2(ignored_target_value):
  890. return 102
  891. def transform_target_plus_two(target_value):
  892. return target_value+2
  893. subset = AvalancheSubset(dataset_mnist,
  894. indices=[3000, 8, 4, 1010, 12],
  895. target_transform=transform_target_to_constant)
  896. subset.target_transform = transform_target_to_constant2
  897. dataset = AvalancheSubset(subset,
  898. indices=[0, 3, 1],
  899. target_transform=transform_target_plus_two)
  900. self.assertEqual(5, len(subset))
  901. self.assertEqual(3, len(dataset))
  902. x3, y3, t3 = dataset[0]
  903. x4, y4, t4 = dataset[1]
  904. self.assertTrue(pil_images_equal(x, x3))
  905. self.assertEqual(104, y3)
  906. self.assertTrue(pil_images_equal(x2, x4))
  907. self.assertEqual(104, y4)
  908. self.assertFalse(pil_images_equal(x, x4))
  909. self.assertFalse(pil_images_equal(x2, x3))
  910. def test_avalanche_avalanche_subset_recursion_sub_class_mapping(self):
  911. dataset_mnist = MNIST(
  912. root=expanduser("~") + "/.avalanche/data/mnist/",
  913. download=True)
  914. x, y = dataset_mnist[3000]
  915. x2, y2 = dataset_mnist[1010]
  916. class_mapping = list(range(10))
  917. random.shuffle(class_mapping)
  918. subset = AvalancheSubset(dataset_mnist,
  919. indices=[3000, 8, 4, 1010, 12],
  920. class_mapping=class_mapping)
  921. dataset = AvalancheSubset(subset,
  922. indices=[0, 3, 1])
  923. self.assertEqual(5, len(subset))
  924. self.assertEqual(3, len(dataset))
  925. x3, y3, t3 = dataset[0]
  926. x4, y4, t4 = dataset[1]
  927. self.assertTrue(pil_images_equal(x, x3))
  928. expected_y3 = class_mapping[y]
  929. self.assertEqual(expected_y3, y3)
  930. self.assertTrue(pil_images_equal(x2, x4))
  931. expected_y4 = class_mapping[y2]
  932. self.assertEqual(expected_y4, y4)
  933. self.assertFalse(pil_images_equal(x, x4))
  934. self.assertFalse(pil_images_equal(x2, x3))
  935. def test_avalanche_avalanche_subset_recursion_up_class_mapping(self):
  936. dataset_mnist = MNIST(
  937. root=expanduser("~") + "/.avalanche/data/mnist/",
  938. download=True)
  939. x, y = dataset_mnist[3000]
  940. x2, y2 = dataset_mnist[1010]
  941. class_mapping = list(range(10))
  942. random.shuffle(class_mapping)
  943. subset = AvalancheSubset(dataset_mnist,
  944. indices=[3000, 8, 4, 1010, 12])
  945. dataset = AvalancheSubset(subset,
  946. indices=[0, 3, 1],
  947. class_mapping=class_mapping)
  948. self.assertEqual(5, len(subset))
  949. self.assertEqual(3, len(dataset))
  950. x3, y3, t3 = dataset[0]
  951. x4, y4, t4 = dataset[1]
  952. self.assertTrue(pil_images_equal(x, x3))
  953. expected_y3 = class_mapping[y]
  954. self.assertEqual(expected_y3, y3)
  955. self.assertTrue(pil_images_equal(x2, x4))
  956. expected_y4 = class_mapping[y2]
  957. self.assertEqual(expected_y4, y4)
  958. self.assertFalse(pil_images_equal(x, x4))
  959. self.assertFalse(pil_images_equal(x2, x3))
  960. def test_avalanche_avalanche_subset_recursion_mix_class_mapping(self):
  961. dataset_mnist = MNIST(
  962. root=expanduser("~") + "/.avalanche/data/mnist/",
  963. download=True)
  964. x, y = dataset_mnist[3000]
  965. x2, y2 = dataset_mnist[1010]
  966. class_mapping = list(range(10))
  967. class_mapping2 = list(range(10))
  968. random.shuffle(class_mapping)
  969. random.shuffle(class_mapping2)
  970. subset = AvalancheSubset(dataset_mnist,
  971. indices=[3000, 8, 4, 1010, 12],
  972. class_mapping=class_mapping)
  973. dataset = AvalancheSubset(subset,
  974. indices=[0, 3, 1],
  975. class_mapping=class_mapping2)
  976. self.assertEqual(5, len(subset))
  977. self.assertEqual(3, len(dataset))
  978. x3, y3, t3 = dataset[0]
  979. x4, y4, t4 = dataset[1]
  980. self.assertTrue(pil_images_equal(x, x3))
  981. expected_y3 = class_mapping2[class_mapping[y]]
  982. self.assertEqual(expected_y3, y3)
  983. self.assertTrue(pil_images_equal(x2, x4))
  984. expected_y4 = class_mapping2[class_mapping[y2]]
  985. self.assertEqual(expected_y4, y4)
  986. self.assertFalse(pil_images_equal(x, x4))
  987. self.assertFalse(pil_images_equal(x2, x3))
  988. def test_avalanche_avalanche_subset_concat_stack_overflow(self):
  989. d_sz = 25
  990. tensor_x = torch.rand(d_sz, 3, 28, 28)
  991. tensor_y = torch.randint(0, 10, (d_sz,))
  992. tensor_t = torch.randint(0, 10, (d_sz,))
  993. dataset = AvalancheTensorDataset(
  994. tensor_x, tensor_y, task_labels=tensor_t)
  995. dataset_hierarchy_depth = 500
  996. rolling_indices: List[List[int]] = []
  997. expect_indices: List[List[int]] = []
  998. for _ in range(dataset_hierarchy_depth):
  999. idx_permuted = list(range(d_sz))
  1000. random.shuffle(idx_permuted)
  1001. rolling_indices.append(idx_permuted)
  1002. forward_indices = range(d_sz)
  1003. expect_indices.append(list(forward_indices))
  1004. for idx in range(dataset_hierarchy_depth):
  1005. forward_indices = [forward_indices[x] for x in rolling_indices[idx]]
  1006. expect_indices.append(forward_indices)
  1007. expect_indices = list(reversed(expect_indices))
  1008. leaf = dataset
  1009. for idx in range(dataset_hierarchy_depth):
  1010. intermediate_idx_test = (dataset_hierarchy_depth - 1) - idx
  1011. subset = AvalancheSubset(leaf, indices=rolling_indices[idx])
  1012. leaf = AvalancheConcatDataset((subset, leaf))
  1013. # Regression test for #616 (second bug)
  1014. # https://github.com/ContinualAI/avalanche/issues/616#issuecomment-848852287
  1015. all_targets = []
  1016. for c_dataset in leaf._dataset_list:
  1017. all_targets += c_dataset.targets
  1018. all_targets = torch.tensor(all_targets)
  1019. for idx_internal in range(idx+1):
  1020. leaf_range = range(idx_internal * d_sz,
  1021. (idx_internal + 1) * d_sz)
  1022. permuted = expect_indices[idx_internal+intermediate_idx_test]
  1023. self.assertTrue(torch.equal(tensor_y[permuted],
  1024. all_targets[leaf_range]))
  1025. self.assertTrue(torch.equal(tensor_y,
  1026. all_targets[-d_sz:]))
  1027. self.assertEqual(d_sz * dataset_hierarchy_depth + d_sz, len(leaf))
  1028. for idx in range(dataset_hierarchy_depth):
  1029. leaf_range = range(idx*d_sz, (idx+1) * d_sz)
  1030. permuted = expect_indices[idx]
  1031. self.assertTrue(torch.equal(tensor_x[permuted],
  1032. leaf[leaf_range][0]))
  1033. self.assertTrue(torch.equal(tensor_y[permuted],
  1034. leaf[leaf_range][1]))
  1035. self.assertTrue(torch.equal(tensor_y[permuted],
  1036. torch.tensor(leaf.targets)[leaf_range]))
  1037. self.assertTrue(torch.equal(tensor_t[permuted],
  1038. leaf[leaf_range][2]))
  1039. self.assertTrue(torch.equal(tensor_x,
  1040. leaf[d_sz*dataset_hierarchy_depth:][0]))
  1041. self.assertTrue(torch.equal(tensor_y,
  1042. leaf[d_sz*dataset_hierarchy_depth:][1]))
  1043. self.assertTrue(torch.equal(
  1044. tensor_y,
  1045. torch.tensor(leaf.targets)[d_sz * dataset_hierarchy_depth:]))
  1046. self.assertTrue(torch.equal(tensor_t,
  1047. leaf[d_sz*dataset_hierarchy_depth:][2]))
  1048. class TransformationSubsetTests(unittest.TestCase):
  1049. def test_avalanche_subset_transform(self):
  1050. dataset_mnist = MNIST(root=expanduser("~") + "/.avalanche/data/mnist/",
  1051. download=True)
  1052. x, y = dataset_mnist[0]
  1053. dataset = AvalancheSubset(dataset_mnist, transform=ToTensor())
  1054. x2, y2, t2 = dataset[0]
  1055. self.assertIsInstance(x2, Tensor)
  1056. self.assertIsInstance(y2, int)
  1057. self.assertIsInstance(t2, int)
  1058. self.assertTrue(torch.equal(ToTensor()(x), x2))
  1059. self.assertEqual(y, y2)
  1060. self.assertEqual(0, t2)
  1061. def test_avalanche_subset_composition(self):
  1062. dataset_mnist = MNIST(root=expanduser("~") + "/.avalanche/data/mnist/",
  1063. download=True, transform=RandomCrop(16))
  1064. x, y = dataset_mnist[0]
  1065. self.assertIsInstance(x, Image)
  1066. self.assertEqual([x.width, x.height], [16, 16])
  1067. self.assertIsInstance(y, int)
  1068. dataset = AvalancheSubset(
  1069. dataset_mnist, transform=ToTensor(),
  1070. target_transform=lambda target: -1)
  1071. x2, y2, t2 = dataset[0]
  1072. self.assertIsInstance(x2, Tensor)
  1073. self.assertEqual(x2.shape, (1, 16, 16))
  1074. self.assertIsInstance(y2, int)
  1075. self.assertIsInstance(t2, int)
  1076. self.assertEqual(y2, -1)
  1077. self.assertEqual(0, t2)
  1078. def test_avalanche_subset_indices(self):
  1079. dataset_mnist = MNIST(root=expanduser("~") + "/.avalanche/data/mnist/",
  1080. download=True)
  1081. x, y = dataset_mnist[1000]
  1082. x2, y2 = dataset_mnist[1007]
  1083. dataset = AvalancheSubset(
  1084. dataset_mnist, indices=[1000, 1007])
  1085. x3, y3, t3 = dataset[0]
  1086. x4, y4, t4 = dataset[1]
  1087. self.assertTrue(pil_images_equal(x, x3))
  1088. self.assertEqual(y, y3)
  1089. self.assertTrue(pil_images_equal(x2, x4))
  1090. self.assertEqual(y2, y4)
  1091. self.assertFalse(pil_images_equal(x, x4))
  1092. self.assertFalse(pil_images_equal(x2, x3))
  1093. def test_avalanche_subset_mapping(self):
  1094. dataset_mnist = MNIST(root=expanduser("~") + "/.avalanche/data/mnist/",
  1095. download=True)
  1096. _, y = dataset_mnist[1000]
  1097. mapping = list(range(10))
  1098. other_classes = list(mapping)
  1099. other_classes.remove(y)
  1100. swap_y = random.choice(other_classes)
  1101. mapping[y] = swap_y
  1102. mapping[swap_y] = y
  1103. dataset = AvalancheSubset(dataset_mnist, class_mapping=mapping)
  1104. _, y2, _ = dataset[1000]
  1105. self.assertEqual(y2, swap_y)
  1106. def test_avalanche_subset_uniform_task_labels(self):
  1107. dataset_mnist = MNIST(root=expanduser("~") + "/.avalanche/data/mnist/",
  1108. download=True)
  1109. x, y = dataset_mnist[1000]
  1110. x2, y2 = dataset_mnist[1007]
  1111. # First, test by passing len(task_labels) == len(dataset_mnist)
  1112. dataset = AvalancheSubset(
  1113. dataset_mnist, indices=[1000, 1007],
  1114. task_labels=[1] * len(dataset_mnist))
  1115. x3, y3, t3 = dataset[0]
  1116. x4, y4, t4 = dataset[1]
  1117. self.assertEqual(y, y3)
  1118. self.assertEqual(1, t3)
  1119. self.assertEqual(y2, y4)
  1120. self.assertEqual(1, t4)
  1121. # Secondly, test by passing len(task_labels) == len(indices)
  1122. dataset = AvalancheSubset(
  1123. dataset_mnist, indices=[1000, 1007],
  1124. task_labels=[1, 1])
  1125. x3, y3, t3 = dataset[0]
  1126. x4, y4, t4 = dataset[1]
  1127. self.assertEqual(y, y3)
  1128. self.assertEqual(1, t3)
  1129. self.assertEqual(y2, y4)
  1130. self.assertEqual(1, t4)
  1131. def test_avalanche_subset_mixed_task_labels(self):
  1132. dataset_mnist = MNIST(root=expanduser("~") + "/.avalanche/data/mnist/",
  1133. download=True)
  1134. x, y = dataset_mnist[1000]
  1135. x2, y2 = dataset_mnist[1007]
  1136. full_task_labels = [1] * len(dataset_mnist)
  1137. full_task_labels[1000] = 2
  1138. # First, test by passing len(task_labels) == len(dataset_mnist)
  1139. dataset = AvalancheSubset(
  1140. dataset_mnist, indices=[1000, 1007],
  1141. task_labels=full_task_labels)
  1142. x3, y3, t3 = dataset[0]
  1143. x4, y4, t4 = dataset[1]
  1144. self.assertEqual(y, y3)
  1145. self.assertEqual(2, t3)
  1146. self.assertEqual(y2, y4)
  1147. self.assertEqual(1, t4)
  1148. # Secondly, test by passing len(task_labels) == len(indices)
  1149. dataset = AvalancheSubset(
  1150. dataset_mnist, indices=[1000, 1007],
  1151. task_labels=[3, 5])
  1152. x3, y3, t3 = dataset[0]
  1153. x4, y4, t4 = dataset[1]
  1154. self.assertEqual(y, y3)
  1155. self.assertEqual(3, t3)
  1156. self.assertEqual(y2, y4)
  1157. self.assertEqual(5, t4)
  1158. def test_avalanche_subset_task_labels_inheritance(self):
  1159. dataset_mnist = MNIST(root=expanduser("~") + "/.avalanche/data/mnist/",
  1160. download=True)
  1161. random_task_labels = [random.randint(0, 10)
  1162. for _ in range(len(dataset_mnist))]
  1163. dataset_orig = AvalancheDataset(dataset_mnist, transform=ToTensor(),
  1164. task_labels=random_task_labels)
  1165. dataset_child = AvalancheSubset(dataset_orig,
  1166. indices=[1000, 1007])
  1167. _, _, t2 = dataset_orig[1000]
  1168. _, _, t5 = dataset_orig[1007]
  1169. _, _, t3 = dataset_child[0]
  1170. _, _, t6 = dataset_child[1]
  1171. self.assertEqual(random_task_labels[1000], t2)
  1172. self.assertEqual(random_task_labels[1007], t5)
  1173. self.assertEqual(random_task_labels[1000], t3)
  1174. self.assertEqual(random_task_labels[1007], t6)
  1175. self.assertListEqual(random_task_labels,
  1176. list(dataset_orig.targets_task_labels))
  1177. self.assertListEqual([random_task_labels[1000],
  1178. random_task_labels[1007]],
  1179. list(dataset_child.targets_task_labels))
  1180. def test_avalanche_subset_collate_fn_inheritance(self):
  1181. tensor_x = torch.rand(200, 3, 28, 28)
  1182. tensor_y = torch.randint(0, 100, (200,))
  1183. tensor_z = torch.randint(0, 100, (200,))
  1184. def my_collate_fn(patterns):
  1185. x_values = torch.stack([pat[0] for pat in patterns], 0)
  1186. y_values = torch.tensor([pat[1] for pat in patterns]) + 1
  1187. z_values = torch.tensor([-1 for _ in patterns])
  1188. t_values = torch.tensor([pat[3] for pat in patterns])
  1189. return x_values, y_values, z_values, t_values
  1190. def my_collate_fn2(patterns):
  1191. x_values = torch.stack([pat[0] for pat in patterns], 0)
  1192. y_values = torch.tensor([pat[1] for pat in patterns]) + 2
  1193. z_values = torch.tensor([-2 for _ in patterns])
  1194. t_values = torch.tensor([pat[3] for pat in patterns])
  1195. return x_values, y_values, z_values, t_values
  1196. whole_dataset = TensorDataset(tensor_x, tensor_y, tensor_z)
  1197. dataset = AvalancheDataset(whole_dataset, collate_fn=my_collate_fn)
  1198. inherited = AvalancheSubset(dataset, indices=list(range(5, 150)),
  1199. collate_fn=my_collate_fn2) # Ok
  1200. x, y, z, t = inherited[0:5]
  1201. self.assertIsInstance(x, Tensor)
  1202. self.assertTrue(torch.equal(tensor_x[5:10], x))
  1203. self.assertTrue(torch.equal(tensor_y[5:10] + 2, y))
  1204. self.assertTrue(torch.equal(torch.full((5,), -2, dtype=torch.long), z))
  1205. self.assertTrue(torch.equal(torch.zeros(5, dtype=torch.long), t))
  1206. classification_dataset = AvalancheDataset(
  1207. whole_dataset, dataset_type=AvalancheDatasetType.CLASSIFICATION)
  1208. with self.assertRaises(ValueError):
  1209. bad_inherited = AvalancheSubset(
  1210. classification_dataset, indices=list(range(5, 150)),
  1211. collate_fn=my_collate_fn)
  1212. ok_inherited_classification = AvalancheSubset(
  1213. classification_dataset, indices=list(range(5, 150)))
  1214. self.assertEqual(AvalancheDatasetType.CLASSIFICATION,
  1215. ok_inherited_classification.dataset_type)
  1216. class TransformationTensorDatasetTests(unittest.TestCase):
  1217. def test_tensor_dataset_helper_tensor_y(self):
  1218. train_exps = [[torch.rand(50, 32, 32), torch.randint(0, 100, (50,))]
  1219. for _ in range(5)]
  1220. test_exps = [[torch.rand(23, 32, 32), torch.randint(0, 100, (23,))]
  1221. for _ in range(5)]
  1222. cl_benchmark = create_generic_benchmark_from_tensor_lists(
  1223. train_tensors=train_exps, test_tensors=test_exps,
  1224. task_labels=[0] * 5)
  1225. self.assertEqual(5, len(cl_benchmark.train_stream))
  1226. self.assertEqual(5, len(cl_benchmark.test_stream))
  1227. self.assertEqual(5, cl_benchmark.n_experiences)
  1228. for exp_id in range(cl_benchmark.n_experiences):
  1229. benchmark_train_x, benchmark_train_y, _ = \
  1230. load_all_dataset(cl_benchmark.train_stream[exp_id].dataset)
  1231. benchmark_test_x, benchmark_test_y, _ = \
  1232. load_all_dataset(cl_benchmark.test_stream[exp_id].dataset)
  1233. self.assertTrue(torch.all(torch.eq(
  1234. train_exps[exp_id][0],
  1235. benchmark_train_x)))
  1236. self.assertTrue(torch.all(torch.eq(
  1237. train_exps[exp_id][1],
  1238. benchmark_train_y)))
  1239. self.assertSequenceEqual(
  1240. train_exps[exp_id][1].tolist(),
  1241. cl_benchmark.train_stream[exp_id].dataset.targets)
  1242. self.assertEqual(0, cl_benchmark.train_stream[exp_id].task_label)
  1243. self.assertTrue(torch.all(torch.eq(
  1244. test_exps[exp_id][0],
  1245. benchmark_test_x)))
  1246. self.assertTrue(torch.all(torch.eq(
  1247. test_exps[exp_id][1],
  1248. benchmark_test_y)))
  1249. self.assertSequenceEqual(
  1250. test_exps[exp_id][1].tolist(),
  1251. cl_benchmark.test_stream[exp_id].dataset.targets)
  1252. self.assertEqual(0, cl_benchmark.test_stream[exp_id].task_label)
  1253. def test_tensor_dataset_helper_list_y(self):
  1254. train_exps = [(torch.rand(50, 32, 32),
  1255. torch.randint(0, 100, (50,)).tolist()) for _ in range(5)]
  1256. test_exps = [(torch.rand(23, 32, 32),
  1257. torch.randint(0, 100, (23,)).tolist()) for _ in range(5)]
  1258. cl_benchmark = create_generic_benchmark_from_tensor_lists(
  1259. train_tensors=train_exps, test_tensors=test_exps,
  1260. task_labels=[0] * 5)
  1261. self.assertEqual(5, len(cl_benchmark.train_stream))
  1262. self.assertEqual(5, len(cl_benchmark.test_stream))
  1263. self.assertEqual(5, cl_benchmark.n_experiences)
  1264. for exp_id in range(cl_benchmark.n_experiences):
  1265. benchmark_train_x, benchmark_train_y, _ = \
  1266. load_all_dataset(cl_benchmark.train_stream[exp_id].dataset)
  1267. benchmark_test_x, benchmark_test_y, _ = \
  1268. load_all_dataset(cl_benchmark.test_stream[exp_id].dataset)
  1269. self.assertTrue(torch.all(torch.eq(
  1270. train_exps[exp_id][0],
  1271. benchmark_train_x)))
  1272. self.assertSequenceEqual(
  1273. train_exps[exp_id][1],
  1274. benchmark_train_y.tolist())
  1275. self.assertSequenceEqual(
  1276. train_exps[exp_id][1],
  1277. cl_benchmark.train_stream[exp_id].dataset.targets)
  1278. self.assertEqual(0, cl_benchmark.train_stream[exp_id].task_label)
  1279. self.assertTrue(torch.all(torch.eq(
  1280. test_exps[exp_id][0],
  1281. benchmark_test_x)))
  1282. self.assertSequenceEqual(
  1283. test_exps[exp_id][1],
  1284. benchmark_test_y.tolist())
  1285. self.assertSequenceEqual(
  1286. test_exps[exp_id][1],
  1287. cl_benchmark.test_stream[exp_id].dataset.targets)
  1288. self.assertEqual(0, cl_benchmark.test_stream[exp_id].task_label)
  1289. class AvalancheDatasetTransformOpsTests(unittest.TestCase):
  1290. def test_avalanche_inherit_groups(self):
  1291. original_dataset = MNIST(
  1292. root=expanduser("~") + "/.avalanche/data/mnist/", download=True
  1293. )
  1294. def plus_one_target(target):
  1295. return target+1
  1296. transform_groups = dict(
  1297. train=(ToTensor(), None),
  1298. eval=(None, plus_one_target)
  1299. )
  1300. x, y = original_dataset[0]
  1301. dataset = AvalancheDataset(original_dataset,
  1302. transform_groups=transform_groups)
  1303. x2, y2, _ = dataset[0]
  1304. self.assertIsInstance(x2, Tensor)
  1305. self.assertIsInstance(y2, int)
  1306. self.assertTrue(torch.equal(ToTensor()(x), x2))
  1307. self.assertEqual(y, y2)
  1308. dataset_eval = dataset.eval()
  1309. x3, y3, _ = dataset_eval[0]
  1310. self.assertIsInstance(x3, PIL.Image.Image)
  1311. self.assertIsInstance(y3, int)
  1312. self.assertEqual(y+1, y3)
  1313. # Regression test for #565
  1314. dataset_inherit = AvalancheDataset(dataset_eval)
  1315. x4, y4, _ = dataset_inherit[0]
  1316. self.assertIsInstance(x4, PIL.Image.Image)
  1317. self.assertIsInstance(y4, int)
  1318. self.assertEqual(y + 1, y4)
  1319. # Regression test for #566
  1320. dataset_sub_train = AvalancheSubset(dataset)
  1321. dataset_sub_eval = dataset_sub_train.eval()
  1322. dataset_sub = AvalancheSubset(dataset_sub_eval, indices=[0])
  1323. x5, y5, _ = dataset_sub[0]
  1324. self.assertIsInstance(x5, PIL.Image.Image)
  1325. self.assertIsInstance(y5, int)
  1326. self.assertEqual(y + 1, y5)
  1327. # End regression tests
  1328. concat_dataset = AvalancheConcatDataset([dataset_sub_eval, dataset_sub])
  1329. x6, y6, _ = concat_dataset[0]
  1330. self.assertIsInstance(x6, PIL.Image.Image)
  1331. self.assertIsInstance(y6, int)
  1332. self.assertEqual(y + 1, y6)
  1333. concat_dataset_no_inherit_initial = \
  1334. AvalancheConcatDataset([dataset_sub_eval, dataset])
  1335. x7, y7, _ = concat_dataset_no_inherit_initial[0]
  1336. self.assertIsInstance(x7, Tensor)
  1337. self.assertIsInstance(y7, int)
  1338. self.assertEqual(y, y7)
  1339. def test_freeze_transforms(self):
  1340. original_dataset = MNIST(
  1341. root=expanduser("~") + "/.avalanche/data/mnist/", download=True
  1342. )
  1343. x, y = original_dataset[0]
  1344. dataset = AvalancheDataset(original_dataset, transform=ToTensor())
  1345. dataset_frozen = dataset.freeze_transforms()
  1346. dataset_frozen.transform = None
  1347. x2, y2, _ = dataset_frozen[0]
  1348. self.assertIsInstance(x2, Tensor)
  1349. self.assertIsInstance(y2, int)
  1350. self.assertTrue(torch.equal(ToTensor()(x), x2))
  1351. self.assertEqual(y, y2)
  1352. dataset.transform = None
  1353. x2, y2, _ = dataset[0]
  1354. self.assertIsInstance(x2, Image)
  1355. x2, y2, _ = dataset_frozen[0]
  1356. self.assertIsInstance(x2, Tensor)
  1357. def test_freeze_transforms_chain(self):
  1358. original_dataset = MNIST(
  1359. root=expanduser("~") + "/.avalanche/data/mnist/", download=True,
  1360. transform=ToTensor()
  1361. )
  1362. x, *_ = original_dataset[0]
  1363. self.assertIsInstance(x, Tensor)
  1364. dataset_transform = AvalancheDataset(original_dataset,
  1365. transform=ToPILImage())
  1366. x, *_ = dataset_transform[0]
  1367. self.assertIsInstance(x, Image)
  1368. dataset_frozen = dataset_transform.freeze_transforms()
  1369. x2, *_ = dataset_frozen[0]
  1370. self.assertIsInstance(x2, Image)
  1371. dataset_transform.transform = None
  1372. x2, *_ = dataset_transform[0]
  1373. self.assertIsInstance(x2, Tensor)
  1374. dataset_frozen.transform = ToTensor()
  1375. x2, *_ = dataset_frozen[0]
  1376. self.assertIsInstance(x2, Tensor)
  1377. dataset_frozen2 = dataset_frozen.freeze_transforms()
  1378. x2, *_ = dataset_frozen2[0]
  1379. self.assertIsInstance(x2, Tensor)
  1380. dataset_frozen.transform = None
  1381. x2, *_ = dataset_frozen2[0]
  1382. self.assertIsInstance(x2, Tensor)
  1383. x2, *_ = dataset_frozen[0]
  1384. self.assertIsInstance(x2, Image)
  1385. def test_add_transforms(self):
  1386. original_dataset = MNIST(
  1387. root=expanduser("~") + "/.avalanche/data/mnist/", download=True
  1388. )
  1389. x, _ = original_dataset[0]
  1390. dataset = AvalancheDataset(original_dataset, transform=ToTensor())
  1391. dataset_added = dataset.add_transforms(ToPILImage())
  1392. x2, *_ = dataset[0]
  1393. x3, *_ = dataset_added[0]
  1394. self.assertIsInstance(x, Image)
  1395. self.assertIsInstance(x2, Tensor)
  1396. self.assertIsInstance(x3, Image)
  1397. def test_add_transforms_chain(self):
  1398. original_dataset = MNIST(
  1399. root=expanduser("~") + "/.avalanche/data/mnist/", download=True
  1400. )
  1401. x, _ = original_dataset[0]
  1402. dataset = AvalancheDataset(original_dataset, transform=ToTensor())
  1403. dataset_added = AvalancheDataset(dataset, transform=ToPILImage())
  1404. x2, *_ = dataset[0]
  1405. x3, *_ = dataset_added[0]
  1406. self.assertIsInstance(x, Image)
  1407. self.assertIsInstance(x2, Tensor)
  1408. self.assertIsInstance(x3, Image)
  1409. def test_transforms_freeze_add_mix(self):
  1410. original_dataset = MNIST(
  1411. root=expanduser("~") + "/.avalanche/data/mnist/", download=True)
  1412. x, _ = original_dataset[0]
  1413. dataset = AvalancheDataset(original_dataset, transform=ToTensor())
  1414. dataset_frozen = dataset.freeze_transforms()
  1415. dataset_added = dataset_frozen.add_transforms(ToPILImage())
  1416. self.assertEqual(None, dataset_frozen.transform)
  1417. x2, *_ = dataset[0]
  1418. x3, *_ = dataset_added[0]
  1419. self.assertIsInstance(x, Image)
  1420. self.assertIsInstance(x2, Tensor)
  1421. self.assertIsInstance(x3, Image)
  1422. dataset_frozen = dataset_added.freeze_transforms()
  1423. dataset_added.transform = None
  1424. x4, *_ = dataset_frozen[0]
  1425. x5, *_ = dataset_added[0]
  1426. self.assertIsInstance(x4, Image)
  1427. self.assertIsInstance(x5, Tensor)
  1428. def test_replace_transforms(self):
  1429. original_dataset = MNIST(
  1430. root=expanduser("~") + "/.avalanche/data/mnist/", download=True)
  1431. x, y = original_dataset[0]
  1432. dataset = AvalancheDataset(original_dataset, transform=ToTensor())
  1433. x2, *_ = dataset[0]
  1434. dataset_reset = dataset.replace_transforms(None, None)
  1435. x3, *_ = dataset_reset[0]
  1436. self.assertIsInstance(x, Image)
  1437. self.assertIsInstance(x2, Tensor)
  1438. self.assertIsInstance(x3, Image)
  1439. dataset_reset.transform = ToTensor()
  1440. x4, *_ = dataset_reset[0]
  1441. self.assertIsInstance(x4, Tensor)
  1442. dataset_reset.replace_transforms(None, None)
  1443. x5, *_ = dataset_reset[0]
  1444. self.assertIsInstance(x5, Tensor)
  1445. dataset_other = AvalancheDataset(dataset_reset)
  1446. dataset_other = dataset_other.replace_transforms(None, lambda l: l + 1)
  1447. _, y6, _ = dataset_other[0]
  1448. self.assertEqual(y+1, y6)
  1449. def test_transforms_replace_freeze_mix(self):
  1450. original_dataset = MNIST(
  1451. root=expanduser("~") + "/.avalanche/data/mnist/", download=True)
  1452. x, _ = original_dataset[0]
  1453. dataset = AvalancheDataset(original_dataset, transform=ToTensor())
  1454. x2, *_ = dataset[0]
  1455. dataset_reset = dataset.replace_transforms(None, None)
  1456. x3, *_ = dataset_reset[0]
  1457. self.assertIsInstance(x, Image)
  1458. self.assertIsInstance(x2, Tensor)
  1459. self.assertIsInstance(x3, Image)
  1460. dataset_frozen = dataset.freeze_transforms()
  1461. x4, *_ = dataset_frozen[0]
  1462. self.assertIsInstance(x4, Tensor)
  1463. dataset_frozen_reset = dataset_frozen.replace_transforms(None, None)
  1464. x5, *_ = dataset_frozen_reset[0]
  1465. self.assertIsInstance(x5, Tensor)
  1466. def test_transforms_groups_base_usage(self):
  1467. original_dataset = MNIST(
  1468. root=expanduser("~") + "/.avalanche/data/mnist/", download=True)
  1469. dataset = AvalancheDataset(
  1470. original_dataset,
  1471. transform_groups=dict(train=(ToTensor(), None),
  1472. eval=(None, Lambda(lambda t: float(t)))))
  1473. x, y, _ = dataset[0]
  1474. self.assertIsInstance(x, Tensor)
  1475. self.assertIsInstance(y, int)
  1476. dataset_test = dataset.eval()
  1477. x2, y2, _ = dataset_test[0]
  1478. x3, y3, _ = dataset[0]
  1479. self.assertIsInstance(x2, Image)
  1480. self.assertIsInstance(y2, float)
  1481. self.assertIsInstance(x3, Tensor)
  1482. self.assertIsInstance(y3, int)
  1483. dataset_train = dataset.train()
  1484. dataset.transform = None
  1485. x4, y4, _ = dataset_train[0]
  1486. x5, y5, _ = dataset[0]
  1487. self.assertIsInstance(x4, Tensor)
  1488. self.assertIsInstance(y4, int)
  1489. self.assertIsInstance(x5, Image)
  1490. self.assertIsInstance(y5, int)
  1491. def test_transforms_groups_constructor_error(self):
  1492. original_dataset = MNIST(
  1493. root=expanduser("~") + "/.avalanche/data/mnist/", download=True)
  1494. with self.assertRaises(Exception):
  1495. # Test tuple has only one element
  1496. dataset = AvalancheDataset(
  1497. original_dataset,
  1498. transform_groups=dict(train=(ToTensor(), None),
  1499. eval=(Lambda(lambda t: float(t)))))
  1500. with self.assertRaises(Exception):
  1501. # Test is not a tuple has only one element
  1502. dataset = AvalancheDataset(
  1503. original_dataset,
  1504. transform_groups=dict(train=(ToTensor(), None),
  1505. eval=[None, Lambda(lambda t: float(t))]))
  1506. with self.assertRaises(Exception):
  1507. # Train is None
  1508. dataset = AvalancheDataset(
  1509. original_dataset,
  1510. transform_groups=dict(train=None,
  1511. eval=(None, Lambda(lambda t: float(t)))))
  1512. with self.assertRaises(Exception):
  1513. # transform_groups is not a dictionary
  1514. dataset = AvalancheDataset(
  1515. original_dataset,
  1516. transform_groups='Hello world!')
  1517. def test_transforms_groups_alternative_default_group(self):
  1518. original_dataset = MNIST(
  1519. root=expanduser("~") + "/.avalanche/data/mnist/", download=True)
  1520. dataset = AvalancheDataset(
  1521. original_dataset,
  1522. transform_groups=dict(train=(ToTensor(), None), eval=(None, None)),
  1523. initial_transform_group='eval')
  1524. x, *_ = dataset[0]
  1525. self.assertIsInstance(x, Image)
  1526. dataset_test = dataset.eval()
  1527. x2, *_ = dataset_test[0]
  1528. x3, *_ = dataset[0]
  1529. self.assertIsInstance(x2, Image)
  1530. self.assertIsInstance(x3, Image)
  1531. def test_transforms_groups_partial_constructor(self):
  1532. original_dataset = MNIST(
  1533. root=expanduser("~") + "/.avalanche/data/mnist/", download=True)
  1534. dataset = AvalancheDataset(
  1535. original_dataset, transform_groups=dict(train=(ToTensor(), None)))
  1536. x, *_ = dataset[0]
  1537. self.assertIsInstance(x, Tensor)
  1538. dataset = dataset.eval()
  1539. x2, *_ = dataset[0]
  1540. self.assertIsInstance(x2, Tensor)
  1541. def test_transforms_groups_multiple_groups(self):
  1542. original_dataset = MNIST(
  1543. root=expanduser("~") + "/.avalanche/data/mnist/", download=True)
  1544. dataset = AvalancheDataset(
  1545. original_dataset,
  1546. transform_groups=dict(
  1547. train=(ToTensor(), None),
  1548. eval=(None, None),
  1549. other=(Compose([ToTensor(),
  1550. Lambda(lambda tensor: tensor.numpy())]), None)))
  1551. x, *_ = dataset[0]
  1552. self.assertIsInstance(x, Tensor)
  1553. dataset = dataset.eval()
  1554. x2, *_ = dataset[0]
  1555. self.assertIsInstance(x2, Image)
  1556. dataset = dataset.with_transforms('other')
  1557. x3, *_ = dataset[0]
  1558. self.assertIsInstance(x3, np.ndarray)
  1559. def test_transforms_add_group(self):
  1560. original_dataset = MNIST(
  1561. root=expanduser("~") + "/.avalanche/data/mnist/", download=True)
  1562. dataset = AvalancheDataset(original_dataset)
  1563. with self.assertRaises(Exception):
  1564. # Can't add existing groups
  1565. dataset = dataset.add_transforms_group('train', ToTensor(), None)
  1566. with self.assertRaises(Exception):
  1567. # Can't add group with bad names (must be str)
  1568. dataset = dataset.add_transforms_group(123, ToTensor(), None)
  1569. # Can't add group with bad names (must be str)
  1570. dataset = dataset.add_transforms_group('other', ToTensor(), None)
  1571. dataset_other = dataset.with_transforms('other')
  1572. x, *_ = dataset[0]
  1573. x2, *_ = dataset_other[0]
  1574. self.assertIsInstance(x, Image)
  1575. self.assertIsInstance(x2, Tensor)
  1576. dataset_other2 = AvalancheDataset(dataset_other)
  1577. # Checks that the other group is used on dataset_other2
  1578. x3, *_ = dataset_other2[0]
  1579. self.assertIsInstance(x3, Tensor)
  1580. with self.assertRaises(Exception):
  1581. # Can't add group if it already exists
  1582. dataset_other2 = dataset_other2.add_transforms_group(
  1583. 'other', ToTensor(), None)
  1584. # Check that the above failed method didn't change the 'other' group
  1585. x4, *_ = dataset_other2[0]
  1586. self.assertIsInstance(x4, Tensor)
  1587. def test_transformation_concat_dataset(self):
  1588. original_dataset = MNIST(
  1589. root=expanduser("~") + "/.avalanche/data/mnist/", download=True)
  1590. original_dataset2 = MNIST(
  1591. root=expanduser("~") + "/.avalanche/data/mnist/", download=True)
  1592. dataset = AvalancheConcatDataset([original_dataset,
  1593. original_dataset2])
  1594. self.assertEqual(len(original_dataset) + len(original_dataset2),
  1595. len(dataset))
  1596. def test_transformation_concat_dataset_groups(self):
  1597. original_dataset = AvalancheDataset(
  1598. MNIST(root=expanduser("~") + "/.avalanche/data/mnist/",
  1599. download=True),
  1600. transform_groups=dict(eval=(None, None), train=(ToTensor(), None)))
  1601. original_dataset2 = AvalancheDataset(
  1602. MNIST(root=expanduser("~") + "/.avalanche/data/mnist/",
  1603. download=True),
  1604. transform_groups=dict(train=(None, None), eval=(ToTensor(), None)))
  1605. dataset = AvalancheConcatDataset([original_dataset,
  1606. original_dataset2])
  1607. self.assertEqual(len(original_dataset) + len(original_dataset2),
  1608. len(dataset))
  1609. x, *_ = dataset[0]
  1610. x2, *_ = dataset[len(original_dataset)]
  1611. self.assertIsInstance(x, Tensor)
  1612. self.assertIsInstance(x2, Image)
  1613. dataset = dataset.eval()
  1614. x3, *_ = dataset[0]
  1615. x4, *_ = dataset[len(original_dataset)]
  1616. self.assertIsInstance(x3, Image)
  1617. self.assertIsInstance(x4, Tensor)
  1618. if __name__ == '__main__':
  1619. unittest.main()