all_mnist.py

################################################################################
# Copyright (c) 2021 ContinualAI.                                              #
# Copyrights licensed under the MIT License.                                   #
# See the accompanying LICENSE file for terms.                                 #
#                                                                              #
# Date: 20-11-2020                                                             #
# Author(s): Vincenzo Lomonaco                                                 #
# E-mail: contact@continualai.org                                              #
# Website: avalanche.continualai.org                                           #
################################################################################
  11. """
  12. In this simple example we show all the different ways you can use MNIST with
  13. Avalanche.
  14. """
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import torch
import argparse
from torch.nn import CrossEntropyLoss
from torch.optim import SGD

from avalanche.benchmarks.classic import PermutedMNIST, RotatedMNIST, \
    SplitMNIST
from avalanche.models import SimpleMLP
from avalanche.training.strategies import Naive

def main(args):
    # Device config
    device = torch.device(f"cuda:{args.cuda}"
                          if torch.cuda.is_available() and
                          args.cuda >= 0 else "cpu")

    # model
    model = SimpleMLP(num_classes=10)

    # Here we show all the MNIST variations we offer in the "classic" benchmarks
    if args.mnist_type == 'permuted':
        scenario = PermutedMNIST(n_experiences=5, seed=1)
    elif args.mnist_type == 'rotated':
        scenario = RotatedMNIST(
            n_experiences=5, rotations_list=[30, 60, 90, 120, 150], seed=1)
    else:
        scenario = SplitMNIST(n_experiences=5, seed=1)
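    # Whichever variant is chosen, the benchmark is split into 5 experiences:
    # PermutedMNIST applies a fixed random pixel permutation per experience,
    # RotatedMNIST applies the rotations listed above, and SplitMNIST assigns
    # a disjoint subset of the 10 digit classes to each experience.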
    # Then we can extract the parallel train and test streams
    train_stream = scenario.train_stream
    test_stream = scenario.test_stream

    # Prepare for training & testing
    optimizer = SGD(model.parameters(), lr=0.001, momentum=0.9)
    criterion = CrossEntropyLoss()
    # Continual learning strategy with default logger
    cl_strategy = Naive(
        model, optimizer, criterion, train_mb_size=32, train_epochs=2,
        eval_mb_size=32, device=device)
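    # Naive simply fine-tunes the model on each new experience, with no
    # mechanism against catastrophic forgetting, so it is typically used as a
    # lower-bound baseline.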
    # train and test loop
    results = []
    for train_task in train_stream:
        print("Current Classes: ", train_task.classes_in_this_experience)
        cl_strategy.train(train_task)
        results.append(cl_strategy.eval(test_stream))
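    # After the loop, `results` holds one dictionary of evaluation metrics per
    # training experience (as returned by cl_strategy.eval); for example,
    # results[-1] contains the metrics computed after the final experience.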


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--mnist_type', type=str, default='split',
                        choices=['rotated', 'permuted', 'split'],
                        help='Choose between MNIST variations: '
                             'rotated, permuted or split.')
    parser.add_argument('--cuda', type=int, default=0,
                        help='Select zero-indexed cuda device. -1 to use CPU.')
    args = parser.parse_args()
    main(args)