test_ni_sit_scenario.py 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194
  1. import unittest
  2. from os.path import expanduser
  3. import torch
  4. from torchvision.datasets import MNIST
  5. from avalanche.benchmarks.scenarios import NIExperience, GenericScenarioStream
  6. from avalanche.benchmarks.utils import AvalancheSubset
  7. from avalanche.benchmarks.scenarios.new_classes.nc_utils import \
  8. make_nc_transformation_subset
  9. from avalanche.benchmarks import ni_benchmark
  class NISITTests(unittest.TestCase):
      """Tests for benchmarks created by ``ni_benchmark`` on MNIST.

      Every experience checked here is expected to carry task label 0, and
      the test stream is expected to be a single experience that holds the
      whole test set.
      """

      # Shared download/cache directory for every MNIST instance used below.
      MNIST_ROOT = expanduser("~/.avalanche/data/mnist/")

      @classmethod
      def _load_mnist(cls):
          """Return the ``(train, test)`` MNIST datasets.

          Both splits are created under ``MNIST_ROOT`` with
          ``download=True``, so the data may be downloaded on first use.
          """
          mnist_train = MNIST(root=cls.MNIST_ROOT, train=True, download=True)
          mnist_test = MNIST(root=cls.MNIST_ROOT, train=False, download=True)
          return mnist_train, mnist_test

      def _assert_sliced_stream(self, sliced_stream, ids, stream_name):
          """Check that ``sliced_stream`` yields experiences ``ids`` in order.

          The sliced stream must be a ``GenericScenarioStream`` named
          ``stream_name`` whose experiences are ``NIExperience`` instances
          whose ``current_experience`` matches ``ids`` position by position.
          """
          self.assertIsInstance(sliced_stream, GenericScenarioStream)
          self.assertEqual(len(ids), len(sliced_stream))
          self.assertEqual(stream_name, sliced_stream.name)
          for position, experience in enumerate(sliced_stream):
              self.assertEqual(ids[position], experience.current_experience)
              self.assertIsInstance(experience, NIExperience)

      def test_ni_sit_single_dataset(self):
          mnist_train, mnist_test = self._load_mnist()
          my_ni_benchmark = ni_benchmark(
              mnist_train, mnist_test, 5, shuffle=True, seed=1234,
              balance_experiences=True)

          self.assertEqual(5, my_ni_benchmark.n_experiences)
          self.assertEqual(10, my_ni_benchmark.n_classes)

          # Balanced experiences must each contain patterns of all 10 classes.
          train_classes = my_ni_benchmark.classes_in_experience['train']
          for batch_id in range(5):
              self.assertEqual(10, len(train_classes[batch_id]))

          # Expected size bounds for each train experience: every class
          # contributes count // n_experiences patterns, and the upper bound
          # allows at most one leftover pattern per class on top of that.
          _, unique_count = torch.unique(torch.as_tensor(mnist_train.targets),
                                         return_counts=True)
          min_batch_size = torch.sum(
              unique_count // my_ni_benchmark.n_experiences).item()
          max_batch_size = min_batch_size + my_ni_benchmark.n_classes

          pattern_count = 0
          batch_info: NIExperience
          for batch_id, batch_info in enumerate(my_ni_benchmark.train_stream):
              cur_train_set = batch_info.dataset
              self.assertEqual(0, batch_info.task_label)
              self.assertEqual(batch_id, batch_info.current_experience)
              self.assertGreaterEqual(len(cur_train_set), min_batch_size)
              self.assertLessEqual(len(cur_train_set), max_batch_size)
              pattern_count += len(cur_train_set)
          # Together the train experiences hold as many patterns as the
          # original training set.
          self.assertEqual(len(mnist_train), pattern_count)

          # The test stream is a single experience with the full test set.
          self.assertEqual(1, len(my_ni_benchmark.test_stream))
          pattern_count = 0
          for batch_id, batch_info in enumerate(my_ni_benchmark.test_stream):
              cur_test_set = batch_info.dataset
              self.assertEqual(0, batch_info.task_label)
              self.assertEqual(batch_id, batch_info.current_experience)
              pattern_count += len(cur_test_set)
          self.assertEqual(len(mnist_test), pattern_count)

      def test_ni_sit_single_dataset_fixed_assignment(self):
          mnist_train, mnist_test = self._load_mnist()
          ni_benchmark_reference = ni_benchmark(
              mnist_train, mnist_test, 5, shuffle=True, seed=1234)
          reference_assignment = \
              ni_benchmark_reference.train_exps_patterns_assignment

          # A different seed must not matter once the pattern-to-experience
          # assignment is given explicitly.
          my_ni_benchmark = ni_benchmark(
              mnist_train, mnist_test, 5, shuffle=True, seed=4321,
              fixed_exp_assignment=reference_assignment)

          self.assertEqual(ni_benchmark_reference.n_experiences,
                           my_ni_benchmark.n_experiences)
          self.assertEqual(ni_benchmark_reference.train_exps_patterns_assignment,
                           my_ni_benchmark.train_exps_patterns_assignment)
          self.assertEqual(ni_benchmark_reference.exp_structure,
                           my_ni_benchmark.exp_structure)

      def test_ni_sit_single_dataset_reproducibility_data(self):
          mnist_train, mnist_test = self._load_mnist()
          ni_benchmark_reference = ni_benchmark(
              mnist_train, mnist_test, 5, shuffle=True, seed=1234)
          rep_data = ni_benchmark_reference.get_reproducibility_data()

          # n_experiences is passed as 0: the layout checked below must come
          # from the reproducibility data alone.
          my_ni_benchmark = ni_benchmark(
              mnist_train, mnist_test, 0, reproducibility_data=rep_data)

          self.assertEqual(ni_benchmark_reference.n_experiences,
                           my_ni_benchmark.n_experiences)
          self.assertEqual(ni_benchmark_reference.train_exps_patterns_assignment,
                           my_ni_benchmark.train_exps_patterns_assignment)
          self.assertEqual(ni_benchmark_reference.exp_structure,
                           my_ni_benchmark.exp_structure)

      def test_ni_sit_multi_dataset_merge(self):
          # Remaps (presumably indexed by original label) classes 5-9 onto
          # 0-4, so both halves built below use the labels 0-4.
          split_mapping = [0, 1, 2, 3, 4, 0, 1, 2, 3, 4]
          mnist_train, mnist_test = self._load_mnist()

          train_part1 = make_nc_transformation_subset(
              mnist_train, None, None, range(5))
          train_part2 = make_nc_transformation_subset(
              mnist_train, None, None, range(5, 10))
          train_part2 = AvalancheSubset(
              train_part2, class_mapping=split_mapping)

          test_part1 = make_nc_transformation_subset(
              mnist_test, None, None, range(5))
          test_part2 = make_nc_transformation_subset(
              mnist_test, None, None, range(5, 10))
          test_part2 = AvalancheSubset(
              test_part2, class_mapping=split_mapping)

          my_ni_benchmark = ni_benchmark(
              [train_part1, train_part2], [test_part1, test_part2], 5,
              shuffle=True, seed=1234, balance_experiences=True)

          # Although both parts share labels 0-4, the merged benchmark is
          # expected to expose 10 distinct classes, all of them present in
          # every (balanced) train experience.
          self.assertEqual(5, my_ni_benchmark.n_experiences)
          self.assertEqual(10, my_ni_benchmark.n_classes)
          train_classes = my_ni_benchmark.classes_in_experience['train']
          all_classes = set()
          for batch_id in range(5):
              self.assertEqual(10, len(train_classes[batch_id]))
              all_classes.update(train_classes[batch_id])
          self.assertEqual(10, len(all_classes))

      def test_ni_sit_slicing(self):
          mnist_train, mnist_test = self._load_mnist()
          my_ni_benchmark = ni_benchmark(
              mnist_train, mnist_test, 5, shuffle=True, seed=1234)

          experience: NIExperience
          for batch_id, experience in enumerate(my_ni_benchmark.train_stream):
              self.assertEqual(batch_id, experience.current_experience)
              self.assertIsInstance(experience, NIExperience)

          self.assertEqual(1, len(my_ni_benchmark.test_stream))
          for batch_id, experience in enumerate(my_ni_benchmark.test_stream):
              self.assertEqual(batch_id, experience.current_experience)
              self.assertIsInstance(experience, NIExperience)

          # Indexing a stream with a list of ids yields a sub-stream that
          # keeps the stream name and the requested experiences, in order.
          iterable_slice = [3, 4, 1]
          self._assert_sliced_stream(
              my_ni_benchmark.train_stream[iterable_slice],
              iterable_slice, 'train')

          with self.assertRaises(IndexError):
              # The test stream only has one element (the complete test set)
              my_ni_benchmark.test_stream[iterable_slice]

          # Repeated ids are accepted: the single test experience three times.
          iterable_slice = [0, 0, 0]
          self._assert_sliced_stream(
              my_ni_benchmark.test_stream[iterable_slice],
              iterable_slice, 'test')
  if __name__ == '__main__':
      # When this file is executed directly, run every TestCase it defines
      # through unittest's command-line runner (module='__main__' is the
      # runner's default, spelled out for clarity).
      unittest.main(module='__main__')