123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120 |
- import unittest
- from torch.utils.data.dataloader import DataLoader
- from avalanche.benchmarks import Experience, SplitCIFAR100, SplitCIFAR110
- from tests.unit_tests_utils import load_experience_train_eval
- CIFAR10_DOWNLOADS = 0
- CIFAR10_DOWNLOAD_METHOD = None
- CIFAR100_DOWNLOADS = 0
- CIFAR100_DOWNLOAD_METHOD = None
- class CIFAR100BenchmarksTests(unittest.TestCase):
- def setUp(self):
- import avalanche.benchmarks.classic.ccifar100 as ccifar100
- global CIFAR10_DOWNLOAD_METHOD, CIFAR100_DOWNLOAD_METHOD
- CIFAR10_DOWNLOAD_METHOD = ccifar100._get_cifar10_dataset
- CIFAR100_DOWNLOAD_METHOD = ccifar100._get_cifar100_dataset
- def count_downloads_c10(*args, **kwargs):
- global CIFAR10_DOWNLOADS
- CIFAR10_DOWNLOADS += 1
- return CIFAR10_DOWNLOAD_METHOD(*args, **kwargs)
- def count_downloads_c100(*args, **kwargs):
- global CIFAR100_DOWNLOADS
- CIFAR100_DOWNLOADS += 1
- return CIFAR100_DOWNLOAD_METHOD(*args, **kwargs)
- ccifar100._get_cifar10_dataset = count_downloads_c10
- ccifar100._get_cifar100_dataset = count_downloads_c100
- def tearDown(self):
- global CIFAR10_DOWNLOAD_METHOD, CIFAR100_DOWNLOAD_METHOD
- if CIFAR10_DOWNLOAD_METHOD is not None:
- import avalanche.benchmarks.classic.ccifar100 as ccifar100
- ccifar100._get_cifar10_dataset = CIFAR10_DOWNLOAD_METHOD
- CIFAR10_DOWNLOAD_METHOD = None
- if CIFAR100_DOWNLOAD_METHOD is not None:
- import avalanche.benchmarks.classic.ccifar100 as ccifar100
- ccifar100._get_cifar100_dataset = CIFAR100_DOWNLOAD_METHOD
- CIFAR100_DOWNLOAD_METHOD = None
- def test_SplitCifar100_benchmark(self):
- benchmark = SplitCIFAR100(5)
- self.assertEqual(5, len(benchmark.train_stream))
- self.assertEqual(5, len(benchmark.test_stream))
- train_sz = 0
- for experience in benchmark.train_stream:
- self.assertIsInstance(experience, Experience)
- train_sz += len(experience.dataset)
- # Regression test for 575
- load_experience_train_eval(experience)
- self.assertEqual(50000, train_sz)
- test_sz = 0
- for experience in benchmark.test_stream:
- self.assertIsInstance(experience, Experience)
- test_sz += len(experience.dataset)
- # Regression test for 575
- load_experience_train_eval(experience)
- self.assertEqual(10000, test_sz)
- def test_SplitCifar110_benchmark(self):
- benchmark = SplitCIFAR110(6)
- self.assertEqual(6, len(benchmark.train_stream))
- self.assertEqual(6, len(benchmark.test_stream))
- train_sz = 0
- for experience in benchmark.train_stream:
- self.assertIsInstance(experience, Experience)
- train_sz += len(experience.dataset)
- load_experience_train_eval(experience)
- self.assertEqual(50000 * 2, train_sz)
- test_sz = 0
- for experience in benchmark.test_stream:
- self.assertIsInstance(experience, Experience)
- test_sz += len(experience.dataset)
- load_experience_train_eval(experience)
- self.assertEqual(10000 * 2, test_sz)
- def test_SplitCifar100_benchmark_download_once(self):
- global CIFAR100_DOWNLOADS, CIFAR10_DOWNLOADS
- CIFAR100_DOWNLOADS = 0
- CIFAR10_DOWNLOADS = 0
- benchmark = SplitCIFAR100(5)
- self.assertEqual(5, len(benchmark.train_stream))
- self.assertEqual(5, len(benchmark.test_stream))
- self.assertEqual(1, CIFAR100_DOWNLOADS)
- self.assertEqual(0, CIFAR10_DOWNLOADS)
- def test_SplitCifar110_benchmark_download_once(self):
- global CIFAR100_DOWNLOADS, CIFAR10_DOWNLOADS
- CIFAR100_DOWNLOADS = 0
- CIFAR10_DOWNLOADS = 0
- benchmark = SplitCIFAR110(6)
- self.assertEqual(6, len(benchmark.train_stream))
- self.assertEqual(6, len(benchmark.test_stream))
- self.assertEqual(1, CIFAR100_DOWNLOADS)
- self.assertEqual(1, CIFAR10_DOWNLOADS)
- if __name__ == '__main__':
- unittest.main()
|