multihead.py

################################################################################
# Copyright (c) 2021 ContinualAI.                                              #
# Copyrights licensed under the MIT License.                                   #
# See the accompanying LICENSE file for terms.                                 #
#                                                                              #
# Date: 01-12-2020                                                             #
# Author(s): Andrea Cossu                                                      #
# E-mail: contact@continualai.org                                              #
# Website: avalanche.continualai.org                                           #
################################################################################
  11. """
  12. This example trains a Multi-head model on Split MNIST with Elastich Weight
  13. Consolidation. Each experience has a different task label, which is used at test
  14. time to select the appropriate head.
  15. """
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse

import torch
from torch.nn import CrossEntropyLoss
from torch.optim import Adam

from avalanche.benchmarks.classic import SplitMNIST
from avalanche.models import MTSimpleMLP
from avalanche.training.strategies import EWC
from avalanche.evaluation.metrics import forgetting_metrics, accuracy_metrics
from avalanche.logging import InteractiveLogger
from avalanche.training.plugins import EvaluationPlugin

def main(args):
    # Config
    device = torch.device(
        f"cuda:{args.cuda}"
        if torch.cuda.is_available() and args.cuda >= 0
        else "cpu"
    )

    # model
    model = MTSimpleMLP()

    # CL Benchmark Creation
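    # return_task_id=True gives each experience its own task label; the
    # multi-head classifier in MTSimpleMLP uses that label to select the
    # head for the corresponding task at training and test time.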
    scenario = SplitMNIST(n_experiences=5, return_task_id=True)
    train_stream = scenario.train_stream
    test_stream = scenario.test_stream
    # Prepare for training & testing
    optimizer = Adam(model.parameters(), lr=0.01)
    criterion = CrossEntropyLoss()

    # choose some metrics and evaluation method
    interactive_logger = InteractiveLogger()
    eval_plugin = EvaluationPlugin(
        accuracy_metrics(
            minibatch=False, epoch=True, experience=True, stream=True),
        forgetting_metrics(experience=True),
        loggers=[interactive_logger])
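    # EWC adds a quadratic penalty that discourages changing parameters that
    # were important for earlier experiences (importance estimated from the
    # Fisher information); ewc_lambda scales the strength of that penalty.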
    # Choose a CL strategy
    strategy = EWC(
        model=model, optimizer=optimizer, criterion=criterion,
        train_mb_size=128, train_epochs=3, eval_mb_size=128, device=device,
        evaluator=eval_plugin,
        ewc_lambda=0.4)
    # train and test loop
    for train_task in train_stream:
        strategy.train(train_task)
        strategy.eval(test_stream)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--cuda', type=int, default=0,
                        help='Select zero-indexed cuda device. -1 to use CPU.')
    args = parser.parse_args()
    main(args)
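
# Example usage (assuming Avalanche and its dependencies are installed):
#   python multihead.py --cuda 0    # use the first CUDA device, if available
#   python multihead.py --cuda -1   # run on CPU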