# test_mnist_benckmarks.py
  1. import unittest
  2. from avalanche.benchmarks import PermutedMNIST, Experience, RotatedMNIST, \
  3. SplitMNIST
  4. from tests.unit_tests_utils import load_experience_train_eval
  5. MNIST_DOWNLOADS = 0
  6. MNIST_DOWNLOAD_METHOD = None
  7. class MNISTBenchmarksTests(unittest.TestCase):
  8. def setUp(self):
  9. import avalanche.benchmarks.classic.cmnist as cmnist
  10. global MNIST_DOWNLOAD_METHOD
  11. MNIST_DOWNLOAD_METHOD = cmnist._get_mnist_dataset
  12. def count_downloads(*args, **kwargs):
  13. global MNIST_DOWNLOADS
  14. MNIST_DOWNLOADS += 1
  15. return MNIST_DOWNLOAD_METHOD(*args, **kwargs)
  16. cmnist._get_mnist_dataset = count_downloads
  17. def tearDown(self):
  18. global MNIST_DOWNLOAD_METHOD
  19. if MNIST_DOWNLOAD_METHOD is not None:
  20. import avalanche.benchmarks.classic.cmnist as cmnist
  21. cmnist._get_mnist_dataset = MNIST_DOWNLOAD_METHOD
  22. MNIST_DOWNLOAD_METHOD = None
  23. def test_SplitMNIST_benchmark(self):
  24. benchmark = SplitMNIST(5)
  25. self.assertEqual(5, len(benchmark.train_stream))
  26. self.assertEqual(5, len(benchmark.test_stream))
  27. train_sz = 0
  28. for experience in benchmark.train_stream:
  29. self.assertIsInstance(experience, Experience)
  30. train_sz += len(experience.dataset)
  31. # Regression test for 572
  32. load_experience_train_eval(experience)
  33. self.assertEqual(60000, train_sz)
  34. test_sz = 0
  35. for experience in benchmark.test_stream:
  36. self.assertIsInstance(experience, Experience)
  37. test_sz += len(experience.dataset)
  38. # Regression test for 572
  39. load_experience_train_eval(experience)
  40. self.assertEqual(10000, test_sz)
  41. def test_PermutedMNIST_benchmark(self):
  42. benchmark = PermutedMNIST(3)
  43. self.assertEqual(3, len(benchmark.train_stream))
  44. self.assertEqual(3, len(benchmark.test_stream))
  45. for experience in benchmark.train_stream:
  46. self.assertIsInstance(experience, Experience)
  47. self.assertEqual(60000, len(experience.dataset))
  48. load_experience_train_eval(experience)
  49. for experience in benchmark.test_stream:
  50. self.assertIsInstance(experience, Experience)
  51. self.assertEqual(10000, len(experience.dataset))
  52. load_experience_train_eval(experience)
  53. def test_RotatedMNIST_benchmark(self):
  54. benchmark = RotatedMNIST(3)
  55. self.assertEqual(3, len(benchmark.train_stream))
  56. self.assertEqual(3, len(benchmark.test_stream))
  57. for experience in benchmark.train_stream:
  58. self.assertIsInstance(experience, Experience)
  59. self.assertEqual(60000, len(experience.dataset))
  60. load_experience_train_eval(experience)
  61. for experience in benchmark.test_stream:
  62. self.assertIsInstance(experience, Experience)
  63. self.assertEqual(10000, len(experience.dataset))
  64. load_experience_train_eval(experience)
  65. def test_PermutedMNIST_benchmark_download_once(self):
  66. global MNIST_DOWNLOADS
  67. MNIST_DOWNLOADS = 0
  68. benchmark = PermutedMNIST(3)
  69. self.assertEqual(3, len(benchmark.train_stream))
  70. self.assertEqual(3, len(benchmark.test_stream))
  71. self.assertEqual(1, MNIST_DOWNLOADS)
  72. def test_RotatedMNIST_benchmark_download_once(self):
  73. global MNIST_DOWNLOADS
  74. MNIST_DOWNLOADS = 0
  75. benchmark = RotatedMNIST(3)
  76. self.assertEqual(3, len(benchmark.train_stream))
  77. self.assertEqual(3, len(benchmark.test_stream))
  78. self.assertEqual(1, MNIST_DOWNLOADS)
  79. # def test_PermutedMNIST_benchmark_performance(self):
  80. # import time
  81. # from torch.utils.data.dataloader import DataLoader
  82. # start_time = time.time()
  83. # benchmark = PermutedMNIST(10)
  84. #
  85. # for experience in benchmark.train_stream:
  86. # self.assertIsInstance(experience, Experience)
  87. # self.assertEqual(60000, len(experience.dataset))
  88. # all_targets = sum(experience.dataset.targets)
  89. #
  90. # # dataset = experience.dataset
  91. # # loader = DataLoader(dataset, num_workers=4, shuffle=True,
  92. # # batch_size=256)
  93. # # for batch in loader:
  94. # # x, y, t = batch
  95. #
  96. # for experience in benchmark.test_stream:
  97. # self.assertIsInstance(experience, Experience)
  98. # self.assertEqual(10000, len(experience.dataset))
  99. # all_targets = sum(experience.dataset.targets)
  100. #
  101. # # dataset = experience.dataset
  102. # # loader = DataLoader(dataset, num_workers=4, shuffle=True,
  103. # # batch_size=256)
  104. # # for batch in loader:
  105. # # x, y, t = batch
  106. #
  107. # elapsed_time = time.time() - start_time
  108. # print('Elapsed:', elapsed_time)
  109. if __name__ == '__main__':
  110. unittest.main()