gem_agem_mnist.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129
  1. import torch
  2. import argparse
  3. from avalanche.benchmarks import PermutedMNIST, SplitMNIST
  4. from avalanche.training.strategies import GEM, AGEM
  5. from avalanche.models import SimpleMLP
  6. from avalanche.evaluation.metrics import forgetting_metrics, accuracy_metrics, \
  7. loss_metrics
  8. from avalanche.logging import InteractiveLogger
  9. from avalanche.training.plugins import EvaluationPlugin
  10. """
  11. This example tests both GEM and A-GEM on Split MNIST and Permuted MNIST.
  12. GEM is a streaming strategy, that is, it uses only 1 training epoch.
  13. A-GEM may use a larger number of epochs.
  14. Both GEM and A-GEM work with small mini batches (usually with 10 patterns).
  15. Warning1: This implementation of GEM and A-GEM does not use task vectors.
  16. Warning2: GEM is much slower than A-GEM.
  17. Results (learning rate is always 0.1):
  18. GEM-PMNIST (5 experiences):
  19. Hidden size 512. 1 training epoch. 512 patterns per experience, 0.5 memory
  20. strength. Average Accuracy over all experiences at the end of training on the
  21. last experience: 92.6%
  22. GEM-SMNIST:
  23. Patterns per experience: 256, Memory strength: 0.5, hidden size: 256
  24. Average Accuracy over all experiences at the end of training on the last
  25. experience: 93.3%
  26. AGEM-PMNIST (5 experiences):
  27. Patterns per experience = sample size: 256. 256 hidden size, 1 training epoch.
  28. Average Accuracy over all experiences at the end of training on the last
  29. experience: 83.5%
  30. AGEM-SMNIST:
  31. Patterns per experience = sample size: 256, 512, 1024. Performance on previous
  32. tasks remains very bad in terms of forgetting. Changing the number of
  33. training epochs does not change the result.
  34. Hidden size 256.
  35. Results for 1024 patterns per experience and sample size, 1 training epoch.
  36. Average Accuracy over all experiences at the end of training on the last
  37. experience: 67.0%
  38. """
  39. def main(args):
  40. model = SimpleMLP(hidden_size=args.hs)
  41. optimizer = torch.optim.SGD(model.parameters(), lr=args.lr)
  42. criterion = torch.nn.CrossEntropyLoss()
  43. # check if selected GPU is available or use CPU
  44. assert args.cuda == -1 or args.cuda >= 0, "cuda must be -1 or >= 0."
  45. device = torch.device(f"cuda:{args.cuda}"
  46. if torch.cuda.is_available() and
  47. args.cuda >= 0 else "cpu")
  48. print(f'Using device: {device}')
  49. # create scenario
  50. if args.scenario == 'pmnist':
  51. scenario = PermutedMNIST(n_experiences=args.permutations)
  52. elif args.scenario == 'smnist':
  53. scenario = SplitMNIST(n_experiences=5, return_task_id=False)
  54. else:
  55. raise ValueError("Wrong scenario name. Allowed pmnist, smnist.")
  56. # choose some metrics and evaluation method
  57. interactive_logger = InteractiveLogger()
  58. eval_plugin = EvaluationPlugin(
  59. accuracy_metrics(minibatch=True, epoch=True,
  60. experience=True, stream=True),
  61. loss_metrics(minibatch=True, epoch=True,
  62. experience=True, stream=True),
  63. forgetting_metrics(experience=True),
  64. loggers=[interactive_logger])
  65. # create strategy
  66. if args.strategy == 'gem':
  67. strategy = GEM(model, optimizer, criterion, args.patterns_per_exp,
  68. args.memory_strength, train_epochs=args.epochs,
  69. device=device, train_mb_size=10, evaluator=eval_plugin)
  70. elif args.strategy == 'agem':
  71. strategy = AGEM(model, optimizer, criterion, args.patterns_per_exp,
  72. args.sample_size, train_epochs=args.epochs,
  73. device=device, train_mb_size=10, evaluator=eval_plugin)
  74. else:
  75. raise ValueError("Wrong strategy name. Allowed gem, agem.")
  76. # train on the selected scenario with the chosen strategy
  77. print('Starting experiment...')
  78. results = []
  79. for experience in scenario.train_stream:
  80. print("Start training on experience ", experience.current_experience)
  81. strategy.train(experience)
  82. print("End training on experience ", experience.current_experience)
  83. print('Computing accuracy on the test set')
  84. results.append(strategy.eval(scenario.test_stream[:]))
  85. if __name__ == '__main__':
  86. parser = argparse.ArgumentParser()
  87. parser.add_argument('--strategy', type=str, choices=['gem', 'agem'],
  88. default='gem', help='Choose between GEM and A-GEM')
  89. parser.add_argument('--scenario', type=str,
  90. choices=['pmnist', 'smnist'], default='smnist',
  91. help='Choose between Permuted MNIST, Split MNIST.')
  92. parser.add_argument('--patterns_per_exp', type=int, default=256,
  93. help='Patterns to store in the memory for each'
  94. ' experience')
  95. parser.add_argument('--sample_size', type=int, default=256,
  96. help='Number of patterns to sample from memory when \
  97. projecting gradient. A-GEM only.')
  98. parser.add_argument('--memory_strength', type=float, default=0.5,
  99. help='Offset to add to the projection direction. '
  100. 'GEM only.')
  101. parser.add_argument('--lr', type=float, default=1e-1, help='Learning rate.')
  102. parser.add_argument('--hs', type=int, default=256, help='MLP hidden size.')
  103. parser.add_argument('--epochs', type=int, default=1,
  104. help='Number of training epochs.')
  105. parser.add_argument('--permutations', type=int, default=5,
  106. help='Number of experiences in Permuted MNIST.')
  107. parser.add_argument('--cuda', type=int, default=0,
  108. help='Specify GPU id to use. Use CPU if -1.')
  109. args = parser.parse_args()
  110. main(args)