lwf_mnist.py

import torch
import argparse
from avalanche.benchmarks import SplitMNIST
from avalanche.training.strategies import LwF
from avalanche.models import SimpleMLP
from avalanche.evaluation.metrics import forgetting_metrics, \
    accuracy_metrics, loss_metrics
from avalanche.logging import InteractiveLogger
from avalanche.training.plugins import EvaluationPlugin

"""
This example tests Learning without Forgetting (LwF) on Split MNIST.
The performance with default arguments should give an average accuracy
of about 73%.
"""
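
# Example invocation (a sketch; assumes Avalanche and its torchvision/MNIST
# dependencies are installed and the dataset can be downloaded on first run):
#   python lwf_mnist.py --lwf_alpha 1 --epochs 10 --cuda -1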


def main(args):
    model = SimpleMLP(hidden_size=args.hs)
    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr)
    criterion = torch.nn.CrossEntropyLoss()

    # check if selected GPU is available or use CPU
    assert args.cuda == -1 or args.cuda >= 0, "cuda must be -1 or >= 0."
    device = torch.device(f"cuda:{args.cuda}"
                          if torch.cuda.is_available() and
                          args.cuda >= 0 else "cpu")
    print(f'Using device: {device}')

    # create split scenario
    scenario = SplitMNIST(n_experiences=5, return_task_id=False)
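    # Note: with n_experiences=5, SplitMNIST partitions the 10 digit classes
    # into 5 experiences of 2 classes each; return_task_id=False yields a
    # class-incremental (single-task) stream.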

    interactive_logger = InteractiveLogger()
    eval_plugin = EvaluationPlugin(
        accuracy_metrics(
            minibatch=True, epoch=True, experience=True, stream=True),
        loss_metrics(minibatch=True, epoch=True, experience=True, stream=True),
        forgetting_metrics(experience=True),
        loggers=[interactive_logger])

    # create strategy
    assert len(args.lwf_alpha) == 1 or len(args.lwf_alpha) == 5, \
        'Alpha must be a single value or one value per experience (5).'
    lwf_alpha = args.lwf_alpha[0] if len(args.lwf_alpha) == 1 \
        else args.lwf_alpha
    strategy = LwF(model, optimizer, criterion, alpha=lwf_alpha,
                   temperature=args.softmax_temperature,
                   train_epochs=args.epochs, device=device,
                   train_mb_size=args.minibatch_size, evaluator=eval_plugin)
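    # LwF adds a knowledge-distillation penalty: predictions of the model
    # saved before the current experience are distilled into the new model
    # (softened by `temperature`), and `alpha` weighs this term against the
    # cross-entropy loss on the current data.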

    # train on the selected scenario with the chosen strategy
    print('Starting experiment...')
    results = []
    for train_batch_info in scenario.train_stream:
        print("Start training on experience ",
              train_batch_info.current_experience)

        strategy.train(train_batch_info, num_workers=0)
        print("End training on experience ",
              train_batch_info.current_experience)

        print('Computing accuracy on the test set')
        results.append(strategy.eval(scenario.test_stream[:]))
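    # `results` collects the metric dictionaries returned by eval() on the
    # whole test stream after each training experience.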


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--lwf_alpha', nargs='+', type=float,
                        default=[0, 0.5, 1.333, 2.25, 3.2],
                        help='Penalty hyperparameter for LwF. It can be either '
                             'a list with multiple elements (one alpha per '
                             'experience) or a list of one element (same alpha '
                             'for all experiences).')
    parser.add_argument('--softmax_temperature', type=float, default=1,
                        help='Temperature for softmax used in distillation')
    parser.add_argument('--lr', type=float, default=1e-3, help='Learning rate.')
    parser.add_argument('--hs', type=int, default=256, help='MLP hidden size.')
    parser.add_argument('--epochs', type=int, default=10,
                        help='Number of training epochs.')
    parser.add_argument('--minibatch_size', type=int, default=128,
                        help='Minibatch size.')
    parser.add_argument('--cuda', type=int, default=0,
                        help='Specify GPU id to use. Use CPU if -1.')

    args = parser.parse_args()
    main(args)