test_cifar100_benchmarks.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
  1. import unittest
  2. from torch.utils.data.dataloader import DataLoader
  3. from avalanche.benchmarks import Experience, SplitCIFAR100, SplitCIFAR110
  4. from tests.unit_tests_utils import load_experience_train_eval
  5. CIFAR10_DOWNLOADS = 0
  6. CIFAR10_DOWNLOAD_METHOD = None
  7. CIFAR100_DOWNLOADS = 0
  8. CIFAR100_DOWNLOAD_METHOD = None
  9. class CIFAR100BenchmarksTests(unittest.TestCase):
  10. def setUp(self):
  11. import avalanche.benchmarks.classic.ccifar100 as ccifar100
  12. global CIFAR10_DOWNLOAD_METHOD, CIFAR100_DOWNLOAD_METHOD
  13. CIFAR10_DOWNLOAD_METHOD = ccifar100._get_cifar10_dataset
  14. CIFAR100_DOWNLOAD_METHOD = ccifar100._get_cifar100_dataset
  15. def count_downloads_c10(*args, **kwargs):
  16. global CIFAR10_DOWNLOADS
  17. CIFAR10_DOWNLOADS += 1
  18. return CIFAR10_DOWNLOAD_METHOD(*args, **kwargs)
  19. def count_downloads_c100(*args, **kwargs):
  20. global CIFAR100_DOWNLOADS
  21. CIFAR100_DOWNLOADS += 1
  22. return CIFAR100_DOWNLOAD_METHOD(*args, **kwargs)
  23. ccifar100._get_cifar10_dataset = count_downloads_c10
  24. ccifar100._get_cifar100_dataset = count_downloads_c100
  25. def tearDown(self):
  26. global CIFAR10_DOWNLOAD_METHOD, CIFAR100_DOWNLOAD_METHOD
  27. if CIFAR10_DOWNLOAD_METHOD is not None:
  28. import avalanche.benchmarks.classic.ccifar100 as ccifar100
  29. ccifar100._get_cifar10_dataset = CIFAR10_DOWNLOAD_METHOD
  30. CIFAR10_DOWNLOAD_METHOD = None
  31. if CIFAR100_DOWNLOAD_METHOD is not None:
  32. import avalanche.benchmarks.classic.ccifar100 as ccifar100
  33. ccifar100._get_cifar100_dataset = CIFAR100_DOWNLOAD_METHOD
  34. CIFAR100_DOWNLOAD_METHOD = None
  35. def test_SplitCifar100_benchmark(self):
  36. benchmark = SplitCIFAR100(5)
  37. self.assertEqual(5, len(benchmark.train_stream))
  38. self.assertEqual(5, len(benchmark.test_stream))
  39. train_sz = 0
  40. for experience in benchmark.train_stream:
  41. self.assertIsInstance(experience, Experience)
  42. train_sz += len(experience.dataset)
  43. # Regression test for 575
  44. load_experience_train_eval(experience)
  45. self.assertEqual(50000, train_sz)
  46. test_sz = 0
  47. for experience in benchmark.test_stream:
  48. self.assertIsInstance(experience, Experience)
  49. test_sz += len(experience.dataset)
  50. # Regression test for 575
  51. load_experience_train_eval(experience)
  52. self.assertEqual(10000, test_sz)
  53. def test_SplitCifar110_benchmark(self):
  54. benchmark = SplitCIFAR110(6)
  55. self.assertEqual(6, len(benchmark.train_stream))
  56. self.assertEqual(6, len(benchmark.test_stream))
  57. train_sz = 0
  58. for experience in benchmark.train_stream:
  59. self.assertIsInstance(experience, Experience)
  60. train_sz += len(experience.dataset)
  61. load_experience_train_eval(experience)
  62. self.assertEqual(50000 * 2, train_sz)
  63. test_sz = 0
  64. for experience in benchmark.test_stream:
  65. self.assertIsInstance(experience, Experience)
  66. test_sz += len(experience.dataset)
  67. load_experience_train_eval(experience)
  68. self.assertEqual(10000 * 2, test_sz)
  69. def test_SplitCifar100_benchmark_download_once(self):
  70. global CIFAR100_DOWNLOADS, CIFAR10_DOWNLOADS
  71. CIFAR100_DOWNLOADS = 0
  72. CIFAR10_DOWNLOADS = 0
  73. benchmark = SplitCIFAR100(5)
  74. self.assertEqual(5, len(benchmark.train_stream))
  75. self.assertEqual(5, len(benchmark.test_stream))
  76. self.assertEqual(1, CIFAR100_DOWNLOADS)
  77. self.assertEqual(0, CIFAR10_DOWNLOADS)
  78. def test_SplitCifar110_benchmark_download_once(self):
  79. global CIFAR100_DOWNLOADS, CIFAR10_DOWNLOADS
  80. CIFAR100_DOWNLOADS = 0
  81. CIFAR10_DOWNLOADS = 0
  82. benchmark = SplitCIFAR110(6)
  83. self.assertEqual(6, len(benchmark.train_stream))
  84. self.assertEqual(6, len(benchmark.test_stream))
  85. self.assertEqual(1, CIFAR100_DOWNLOADS)
  86. self.assertEqual(1, CIFAR10_DOWNLOADS)
  87. if __name__ == '__main__':
  88. unittest.main()