test_ar1.py 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051
  1. ################################################################################
  2. # Copyright (c) 2021 ContinualAI. #
  3. # Copyrights licensed under the MIT License. #
  4. # See the accompanying LICENSE file for terms. #
  5. # #
  6. # Date: 11-05-2021 #
  7. # Author(s): Antonio Carta #
  8. # E-mail: contact@continualai.org #
  9. # Website: avalanche.continualai.org #
  10. ################################################################################
  11. import torch
  12. import unittest
  13. from sklearn.datasets import make_classification
  14. from sklearn.model_selection import train_test_split
  15. from torch.utils.data import TensorDataset
  16. from avalanche.benchmarks import nc_benchmark
  17. from avalanche.training.strategies import AR1
  18. from tests.training.test_strategies import StrategyTest
  19. class AR1Test(unittest.TestCase):
  20. def test_ar1(self):
  21. my_nc_benchmark = self.load_ar1_benchmark()
  22. strategy = AR1(train_epochs=1, train_mb_size=10, eval_mb_size=10,
  23. rm_sz=200)
  24. StrategyTest.run_strategy(self, my_nc_benchmark, strategy)
  25. def load_ar1_benchmark(self):
  26. """
  27. Returns a NC benchmark from a fake dataset of 10 classes, 5 experiences,
  28. 2 classes per experience. This toy benchmark is intended
  29. """
  30. n_samples_per_class = 50
  31. dataset = make_classification(
  32. n_samples=10 * n_samples_per_class,
  33. n_classes=10,
  34. n_features=224 * 224 * 3, n_informative=6, n_redundant=0)
  35. X = torch.from_numpy(dataset[0]).reshape(-1, 3, 224, 224).float()
  36. y = torch.from_numpy(dataset[1]).long()
  37. train_X, test_X, train_y, test_y = train_test_split(
  38. X, y, train_size=0.6, shuffle=True, stratify=y)
  39. train_dataset = TensorDataset(train_X, train_y)
  40. test_dataset = TensorDataset(test_X, test_y)
  41. my_nc_benchmark = nc_benchmark(
  42. train_dataset, test_dataset, 5, task_labels=False
  43. )
  44. return my_nc_benchmark