test_fmnist_benckmarks.py 1.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455
  1. import unittest
  2. from avalanche.benchmarks import Experience, SplitFMNIST
  3. from tests.unit_tests_utils import load_experience_train_eval
  4. MNIST_DOWNLOADS = 0
  5. MNIST_DOWNLOAD_METHOD = None
  6. class FMNISTBenchmarksTests(unittest.TestCase):
  7. def setUp(self):
  8. import avalanche.benchmarks.classic.cfashion_mnist as cfashion_mnist
  9. global MNIST_DOWNLOAD_METHOD
  10. MNIST_DOWNLOAD_METHOD = cfashion_mnist._get_fmnist_dataset
  11. def count_downloads(*args, **kwargs):
  12. global MNIST_DOWNLOADS
  13. MNIST_DOWNLOADS += 1
  14. return MNIST_DOWNLOAD_METHOD(*args, **kwargs)
  15. cfashion_mnist._get_fmnist_dataset = count_downloads
  16. def tearDown(self):
  17. global MNIST_DOWNLOAD_METHOD
  18. if MNIST_DOWNLOAD_METHOD is not None:
  19. import avalanche.benchmarks.classic.cfashion_mnist as cfashion_mnist
  20. cfashion_mnist._get_fmnist_dataset = MNIST_DOWNLOAD_METHOD
  21. MNIST_DOWNLOAD_METHOD = None
  22. def test_SplitFMNIST_benchmark(self):
  23. benchmark = SplitFMNIST(5)
  24. self.assertEqual(5, len(benchmark.train_stream))
  25. self.assertEqual(5, len(benchmark.test_stream))
  26. train_sz = 0
  27. for experience in benchmark.train_stream:
  28. self.assertIsInstance(experience, Experience)
  29. train_sz += len(experience.dataset)
  30. load_experience_train_eval(experience)
  31. self.assertEqual(60000, train_sz)
  32. test_sz = 0
  33. for experience in benchmark.test_stream:
  34. self.assertIsInstance(experience, Experience)
  35. test_sz += len(experience.dataset)
  36. load_experience_train_eval(experience)
  37. self.assertEqual(10000, test_sz)
  38. if __name__ == '__main__':
  39. unittest.main()