import unittest

from os.path import expanduser

import torch
from torchvision.datasets import MNIST

from avalanche.benchmarks.scenarios import NIExperience, GenericScenarioStream
from avalanche.benchmarks.utils import AvalancheSubset
from avalanche.benchmarks.scenarios.new_classes.nc_utils import \
    make_nc_transformation_subset
from avalanche.benchmarks import ni_benchmark
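

# Unit tests for the ni_benchmark generator (New Instances scenario, single
# incremental task), using MNIST as the source dataset.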
class NISITTests(unittest.TestCase):
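    # Basic NI split of MNIST into 5 experiences with balanced class
    # distribution: checks n_experiences/n_classes, per-experience class
    # coverage, experience sizes, and that every pattern is used exactly once.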
    def test_ni_sit_single_dataset(self):
        mnist_train = MNIST(root=expanduser("~") + "/.avalanche/data/mnist/",
                            train=True, download=True)
        mnist_test = MNIST(root=expanduser("~") + "/.avalanche/data/mnist/",
                           train=False, download=True)
        my_ni_benchmark = ni_benchmark(
            mnist_train, mnist_test, 5, shuffle=True, seed=1234,
            balance_experiences=True)

        self.assertEqual(5, my_ni_benchmark.n_experiences)
        self.assertEqual(10, my_ni_benchmark.n_classes)
        for batch_id in range(5):
            self.assertEqual(
                10,
                len(my_ni_benchmark.classes_in_experience['train'][batch_id])
            )

        _, unique_count = torch.unique(torch.as_tensor(mnist_train.targets),
                                       return_counts=True)

        min_batch_size = torch.sum(unique_count //
                                   my_ni_benchmark.n_experiences).item()
        max_batch_size = min_batch_size + my_ni_benchmark.n_classes

        pattern_count = 0
        batch_info: NIExperience
        for batch_id, batch_info in enumerate(my_ni_benchmark.train_stream):
            cur_train_set = batch_info.dataset
            t = batch_info.task_label
            self.assertEqual(0, t)
            self.assertEqual(batch_id, batch_info.current_experience)
            self.assertGreaterEqual(len(cur_train_set), min_batch_size)
            self.assertLessEqual(len(cur_train_set), max_batch_size)
            pattern_count += len(cur_train_set)
        self.assertEqual(len(mnist_train), pattern_count)

        self.assertEqual(1, len(my_ni_benchmark.test_stream))
        pattern_count = 0
        for batch_id, batch_info in enumerate(my_ni_benchmark.test_stream):
            cur_test_set = batch_info.dataset
            t = batch_info.task_label
            self.assertEqual(0, t)
            self.assertEqual(batch_id, batch_info.current_experience)
            pattern_count += len(cur_test_set)
        self.assertEqual(len(mnist_test), pattern_count)
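
    # Passing fixed_exp_assignment should reproduce the exact
    # pattern-to-experience assignment of a reference benchmark, even when a
    # different shuffling seed is used.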
    def test_ni_sit_single_dataset_fixed_assignment(self):
        mnist_train = MNIST(root=expanduser("~") + "/.avalanche/data/mnist/",
                            train=True, download=True)
        mnist_test = MNIST(root=expanduser("~") + "/.avalanche/data/mnist/",
                           train=False, download=True)
        ni_benchmark_reference = ni_benchmark(
            mnist_train, mnist_test, 5, shuffle=True, seed=1234)

        reference_assignment = ni_benchmark_reference.\
            train_exps_patterns_assignment

        my_ni_benchmark = ni_benchmark(
            mnist_train, mnist_test, 5, shuffle=True, seed=4321,
            fixed_exp_assignment=reference_assignment)

        self.assertEqual(
            ni_benchmark_reference.n_experiences, my_ni_benchmark.n_experiences
        )
        self.assertEqual(ni_benchmark_reference.train_exps_patterns_assignment,
                         my_ni_benchmark.train_exps_patterns_assignment)
        self.assertEqual(ni_benchmark_reference.exp_structure,
                         my_ni_benchmark.exp_structure)
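
    # A benchmark rebuilt from get_reproducibility_data() should match the
    # reference one; the test relies on the stored data overriding the
    # n_experiences argument (0 is passed here).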
    def test_ni_sit_single_dataset_reproducibility_data(self):
        mnist_train = MNIST(root=expanduser("~") + "/.avalanche/data/mnist/",
                            train=True, download=True)
        mnist_test = MNIST(root=expanduser("~") + "/.avalanche/data/mnist/",
                           train=False, download=True)
        ni_benchmark_reference = ni_benchmark(
            mnist_train, mnist_test, 5, shuffle=True, seed=1234)

        rep_data = ni_benchmark_reference.get_reproducibility_data()

        my_ni_benchmark = ni_benchmark(
            mnist_train, mnist_test, 0, reproducibility_data=rep_data)

        self.assertEqual(
            ni_benchmark_reference.n_experiences, my_ni_benchmark.n_experiences
        )
        self.assertEqual(ni_benchmark_reference.train_exps_patterns_assignment,
                         my_ni_benchmark.train_exps_patterns_assignment)
        self.assertEqual(ni_benchmark_reference.exp_structure,
                         my_ni_benchmark.exp_structure)
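
    # Builds the NI benchmark from two train/test dataset parts (classes 0-4
    # plus a remapped 5-9 split) and checks that the merged benchmark still
    # exposes all 10 classes in every experience.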
    def test_ni_sit_multi_dataset_merge(self):
        split_mapping = [0, 1, 2, 3, 4, 0, 1, 2, 3, 4]
        mnist_train = MNIST(root=expanduser("~") + "/.avalanche/data/mnist/",
                            train=True, download=True)
        mnist_test = MNIST(root=expanduser("~") + "/.avalanche/data/mnist/",
                           train=False, download=True)

        train_part1 = make_nc_transformation_subset(
            mnist_train, None, None, range(5))
        train_part2 = make_nc_transformation_subset(
            mnist_train, None, None, range(5, 10))
        train_part2 = AvalancheSubset(
            train_part2, class_mapping=split_mapping)

        test_part1 = make_nc_transformation_subset(
            mnist_test, None, None, range(5))
        test_part2 = make_nc_transformation_subset(
            mnist_test, None, None, range(5, 10))
        test_part2 = AvalancheSubset(test_part2,
                                     class_mapping=split_mapping)

        my_ni_benchmark = ni_benchmark(
            [train_part1, train_part2], [test_part1, test_part2], 5,
            shuffle=True, seed=1234, balance_experiences=True)

        self.assertEqual(5, my_ni_benchmark.n_experiences)
        self.assertEqual(10, my_ni_benchmark.n_classes)
        for batch_id in range(5):
            self.assertEqual(
                10,
                len(my_ni_benchmark.classes_in_experience['train'][batch_id]))

        all_classes = set()
        for batch_id in range(5):
            all_classes.update(
                my_ni_benchmark.classes_in_experience['train'][batch_id])

        self.assertEqual(10, len(all_classes))
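
    # Stream indexing/slicing: experiences report the expected
    # current_experience, slicing with a list of indices yields a
    # GenericScenarioStream of the same length and name, out-of-range indices
    # raise IndexError, and repeated indices are allowed.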
    def test_ni_sit_slicing(self):
        mnist_train = MNIST(root=expanduser("~") + "/.avalanche/data/mnist/",
                            train=True, download=True)
        mnist_test = MNIST(root=expanduser("~") + "/.avalanche/data/mnist/",
                           train=False, download=True)
        my_ni_benchmark = ni_benchmark(
            mnist_train, mnist_test, 5, shuffle=True, seed=1234)

        experience: NIExperience
        for batch_id, experience in enumerate(my_ni_benchmark.train_stream):
            self.assertEqual(batch_id, experience.current_experience)
            self.assertIsInstance(experience, NIExperience)

        self.assertEqual(1, len(my_ni_benchmark.test_stream))
        for batch_id, experience in enumerate(my_ni_benchmark.test_stream):
            self.assertEqual(batch_id, experience.current_experience)
            self.assertIsInstance(experience, NIExperience)

        iterable_slice = [3, 4, 1]
        sliced_stream = my_ni_benchmark.train_stream[iterable_slice]
        self.assertIsInstance(sliced_stream, GenericScenarioStream)
        self.assertEqual(len(iterable_slice), len(sliced_stream))
        self.assertEqual('train', sliced_stream.name)

        for batch_id, experience in enumerate(sliced_stream):
            self.assertEqual(
                iterable_slice[batch_id], experience.current_experience)
            self.assertIsInstance(experience, NIExperience)

        with self.assertRaises(IndexError):
            # The test stream only has one element (the complete test set)
            sliced_stream = my_ni_benchmark.test_stream[iterable_slice]

        iterable_slice = [0, 0, 0]
        sliced_stream = my_ni_benchmark.test_stream[iterable_slice]
        self.assertIsInstance(sliced_stream, GenericScenarioStream)
        self.assertEqual(len(iterable_slice), len(sliced_stream))
        self.assertEqual('test', sliced_stream.name)

        for batch_id, experience in enumerate(sliced_stream):
            self.assertEqual(
                iterable_slice[batch_id], experience.current_experience)
            self.assertIsInstance(experience, NIExperience)


if __name__ == '__main__':
    unittest.main()