test_cifar10_benchmarks.py 2.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667
  1. import unittest
  2. from avalanche.benchmarks import Experience, SplitCIFAR10
  3. from tests.unit_tests_utils import load_experience_train_eval
  4. CIFAR10_DOWNLOADS = 0
  5. CIFAR10_DOWNLOAD_METHOD = None
  6. class CIFAR10BenchmarksTests(unittest.TestCase):
  7. def setUp(self):
  8. import avalanche.benchmarks.classic.ccifar10 as ccifar10
  9. global CIFAR10_DOWNLOAD_METHOD
  10. CIFAR10_DOWNLOAD_METHOD = ccifar10._get_cifar10_dataset
  11. def count_downloads(*args, **kwargs):
  12. global CIFAR10_DOWNLOADS
  13. CIFAR10_DOWNLOADS += 1
  14. return CIFAR10_DOWNLOAD_METHOD(*args, **kwargs)
  15. ccifar10._get_cifar10_dataset = count_downloads
  16. def tearDown(self):
  17. global CIFAR10_DOWNLOAD_METHOD
  18. if CIFAR10_DOWNLOAD_METHOD is not None:
  19. import avalanche.benchmarks.classic.ccifar10 as ccifar10
  20. ccifar10._get_cifar10_dataset = CIFAR10_DOWNLOAD_METHOD
  21. CIFAR10_DOWNLOAD_METHOD = None
  22. def test_SplitCifar10_benchmark(self):
  23. benchmark = SplitCIFAR10(5)
  24. self.assertEqual(5, len(benchmark.train_stream))
  25. self.assertEqual(5, len(benchmark.test_stream))
  26. train_sz = 0
  27. for experience in benchmark.train_stream:
  28. self.assertIsInstance(experience, Experience)
  29. train_sz += len(experience.dataset)
  30. # Regression test for 575
  31. load_experience_train_eval(experience)
  32. self.assertEqual(50000, train_sz)
  33. test_sz = 0
  34. for experience in benchmark.test_stream:
  35. self.assertIsInstance(experience, Experience)
  36. test_sz += len(experience.dataset)
  37. # Regression test for 575
  38. load_experience_train_eval(experience)
  39. self.assertEqual(10000, test_sz)
  40. def test_SplitCifar10_benchmark_download_once(self):
  41. global CIFAR10_DOWNLOADS
  42. CIFAR10_DOWNLOADS = 0
  43. benchmark = SplitCIFAR10(5)
  44. self.assertEqual(5, len(benchmark.train_stream))
  45. self.assertEqual(5, len(benchmark.test_stream))
  46. self.assertEqual(1, CIFAR10_DOWNLOADS)
  47. if __name__ == '__main__':
  48. unittest.main()