ewc_mnist.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. import torch
  2. from os.path import expanduser
  3. import argparse
  4. from torchvision.datasets import MNIST
  5. from torchvision.transforms import ToTensor
  6. from avalanche.benchmarks import PermutedMNIST, nc_benchmark
  7. from avalanche.training.strategies import EWC
  8. from avalanche.models import SimpleMLP
  9. from avalanche.evaluation.metrics import forgetting_metrics, \
  10. accuracy_metrics, loss_metrics, bwt_metrics
  11. from avalanche.logging import InteractiveLogger, TensorboardLogger
  12. from avalanche.training.plugins import EvaluationPlugin
  13. """
This example tests EWC on Split MNIST and Permuted MNIST.
It is possible to choose, among other options, between EWC with separate
penalties and online EWC with a single penalty.
On Permuted MNIST, EWC maintains very good performance on previous tasks
across a wide range of configurations. The average accuracy on previous tasks
at the end of training on all tasks is around 85%,
with a comparable training accuracy.
On Split MNIST, by contrast, EWC is not able to remember previous tasks and
suffers complete forgetting in all configurations. The training accuracy
is above 90%, but the average accuracy on previous tasks is around 20%.
  24. """
  25. def main(args):
  26. model = SimpleMLP(hidden_size=args.hs)
  27. optimizer = torch.optim.SGD(model.parameters(), lr=args.lr)
  28. criterion = torch.nn.CrossEntropyLoss()
  29. # check if selected GPU is available or use CPU
  30. assert args.cuda == -1 or args.cuda >= 0, "cuda must be -1 or >= 0."
  31. device = torch.device(f"cuda:{args.cuda}"
  32. if torch.cuda.is_available() and
  33. args.cuda >= 0 else "cpu")
  34. print(f'Using device: {device}')
  35. # create scenario
  36. if args.scenario == 'pmnist':
  37. scenario = PermutedMNIST(n_experiences=args.permutations)
  38. elif args.scenario == 'smnist':
  39. mnist_train = MNIST(root=expanduser("~") + "/.avalanche/data/mnist/",
  40. train=True, download=True, transform=ToTensor())
  41. mnist_test = MNIST(root=expanduser("~") + "/.avalanche/data/mnist/",
  42. train=False, download=True, transform=ToTensor())
  43. scenario = nc_benchmark(
  44. mnist_train, mnist_test, 5, task_labels=False, seed=1234)
  45. else:
  46. raise ValueError("Wrong scenario name. Allowed pmnist, smnist.")
  47. # choose some metrics and evaluation method
  48. interactive_logger = InteractiveLogger()
  49. tensorboard_logger = TensorboardLogger()
  50. eval_plugin = EvaluationPlugin(
  51. accuracy_metrics(
  52. minibatch=True, epoch=True, experience=True, stream=True),
  53. loss_metrics(
  54. minibatch=True, epoch=True, experience=True, stream=True),
  55. forgetting_metrics(experience=True, stream=True),
  56. bwt_metrics(experience=True, stream=True),
  57. loggers=[interactive_logger, tensorboard_logger])
  58. # create strategy
  59. strategy = EWC(model, optimizer, criterion, args.ewc_lambda,
  60. args.ewc_mode, decay_factor=args.decay_factor,
  61. train_epochs=args.epochs, device=device,
  62. train_mb_size=args.minibatch_size, evaluator=eval_plugin)
  63. # train on the selected scenario with the chosen strategy
  64. print('Starting experiment...')
  65. results = []
  66. for experience in scenario.train_stream:
  67. print("Start training on experience ", experience.current_experience)
  68. strategy.train(experience)
  69. print("End training on experience", experience.current_experience)
  70. print('Computing accuracy on the test set')
  71. results.append(strategy.eval(scenario.test_stream[:]))
  72. if __name__ == '__main__':
  73. parser = argparse.ArgumentParser()
  74. parser.add_argument('--scenario', type=str,
  75. choices=['pmnist', 'smnist'], default='smnist',
  76. help='Choose between Permuted MNIST, Split MNIST.')
  77. parser.add_argument('--ewc_mode', type=str, choices=['separate', 'online'],
  78. default='separate',
  79. help='Choose between EWC and online.')
  80. parser.add_argument('--ewc_lambda', type=float, default=0.4,
  81. help='Penalty hyperparameter for EWC')
  82. parser.add_argument('--decay_factor', type=float, default=0.1,
  83. help='Decay factor for importance '
  84. 'when ewc_mode is online.')
  85. parser.add_argument('--lr', type=float, default=1e-3, help='Learning rate.')
  86. parser.add_argument('--hs', type=int, default=256, help='MLP hidden size.')
  87. parser.add_argument('--epochs', type=int, default=10,
  88. help='Number of training epochs.')
  89. parser.add_argument('--minibatch_size', type=int, default=128,
  90. help='Minibatch size.')
  91. parser.add_argument('--permutations', type=int, default=5,
  92. help='Number of experiences in Permuted MNIST.')
  93. parser.add_argument('--cuda', type=int, default=0,
  94. help='Specify GPU id to use. Use CPU if -1.')
  95. args = parser.parse_args()
  96. main(args)