all_mnist_early_stopping.py

################################################################################
# Copyright (c) 2021 ContinualAI.                                              #
# Copyrights licensed under the MIT License.                                   #
# See the accompanying LICENSE file for terms.                                 #
#                                                                              #
# Date: 20-11-2020                                                             #
# Author(s): Vincenzo Lomonaco                                                 #
# E-mail: contact@continualai.org                                              #
# Website: avalanche.continualai.org                                           #
################################################################################
  11. """
  12. Same example as in all_mnist.py, but using early stopping to dynamically stop
  13. the training procedure when the model converged instead of training for a
  14. fixed number of epochs.
  15. IMPORTANT: In this example we use the test set to detect when the
  16. generalization error stops decreasing. In practice, one should *never* use
  17. the test set for early stopping, but rather measure the generalization
  18. performance on a held-out validation set.
  19. """
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse

import torch
from torch.nn import CrossEntropyLoss
from torch.optim import SGD

from avalanche.benchmarks.classic import PermutedMNIST, RotatedMNIST, \
    SplitMNIST
from avalanche.models import SimpleMLP
from avalanche.training.plugins.early_stopping import EarlyStoppingPlugin
from avalanche.training.strategies import Naive


def main(args):
    # Device config
    device = torch.device(f"cuda:{args.cuda}"
                          if torch.cuda.is_available() and
                          args.cuda >= 0 else "cpu")

    # Model
    model = SimpleMLP(num_classes=10)
    # Here we show all the MNIST variations offered in the "classic" benchmarks
    if args.mnist_type == 'permuted':
        scenario = PermutedMNIST(n_experiences=5, seed=1)
    elif args.mnist_type == 'rotated':
        scenario = RotatedMNIST(
            n_experiences=5, rotations_list=[30, 60, 90, 120, 150], seed=1)
    else:
        scenario = SplitMNIST(n_experiences=5, seed=1)
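
    # NOTE: as warned in the module docstring, this example (ab)uses the test
    # stream for early stopping. A minimal sketch of the proper setup, assuming
    # the `benchmark_with_validation_stream` helper available in recent
    # Avalanche releases (names and defaults may differ across versions):
    #
    #     from avalanche.benchmarks.generators import \
    #         benchmark_with_validation_stream
    #     scenario = benchmark_with_validation_stream(scenario,
    #                                                 validation_size=0.2)
    #
    # The resulting benchmark exposes a validation stream, whose name would
    # then be passed to EarlyStoppingPlugin instead of 'test_stream'.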
    # Then we can extract the parallel train and test streams
    train_stream = scenario.train_stream
    test_stream = scenario.test_stream

    # Prepare for training & testing
    optimizer = SGD(model.parameters(), lr=0.001, momentum=0.9)
    criterion = CrossEntropyLoss()
    # Continual learning strategy with default logger. `eval_every=1` runs an
    # evaluation pass after every epoch, giving the early stopping plugin a
    # fresh metric to monitor; `train_epochs=100` is only an upper bound.
    cl_strategy = Naive(
        model, optimizer, criterion, train_mb_size=32, train_epochs=100,
        eval_mb_size=32, device=device, eval_every=1,
        plugins=[EarlyStoppingPlugin(args.patience, 'test_stream')])
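
    # A hedged note on the plugin: in the Avalanche versions this example
    # targets, EarlyStoppingPlugin(patience, val_stream_name) stops training
    # once the monitored metric on the named stream (stream accuracy by
    # default) has not improved for `patience` consecutive evaluations.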
    # Train and test loop
    results = []
    for train_task, test_task in zip(train_stream, test_stream):
        print("Current Classes: ", train_task.classes_in_this_experience)
        cl_strategy.train(train_task, eval_streams=[test_task])
        results.append(cl_strategy.eval(test_stream))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--mnist_type', type=str, default='split',
                        choices=['rotated', 'permuted', 'split'],
                        help='Choose between MNIST variations: '
                             'rotated, permuted or split.')
    parser.add_argument('--cuda', type=int, default=0,
                        help='Select zero-indexed cuda device. -1 to use CPU.')
    parser.add_argument('--patience', type=int, default=3,
                        help='Number of epochs to wait without generalization '
                             'improvements before stopping the training.')
    args = parser.parse_args()
    main(args)
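
# Example invocation (a sketch using the flags defined above, not official
# project docs):
#   python all_mnist_early_stopping.py --mnist_type permuted --cuda -1 \
#       --patience 5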