joint_training.py 2.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
  1. ################################################################################
  2. # Copyright (c) 2021 ContinualAI. #
  3. # Copyrights licensed under the MIT License. #
  4. # See the accompanying LICENSE file for terms. #
  5. # #
  6. # Date: 20-11-2020 #
  7. # Author(s): Vincenzo Lomonaco #
  8. # E-mail: contact@continualai.org #
  9. # Website: avalanche.continualai.org #
  10. ################################################################################
  11. """
  12. This is a simple example showing how an "offline" upper bound can be
  13. computed. It is useful for seeing the maximum accuracy a model can reach
  14. when it is not hindered by having to learn continually. This is often
  15. referred to as the "cumulative", "joint-training" or "offline" upper bound.
  16. """
  17. from __future__ import absolute_import
  18. from __future__ import division
  19. from __future__ import print_function
  20. import argparse
  21. import torch
  22. from torch.nn import CrossEntropyLoss
  23. from torch.optim import SGD
  24. from avalanche.benchmarks.classic import PermutedMNIST
  25. from avalanche.models import SimpleMLP
  26. from avalanche.training.strategies import JointTraining
  27. def main(args):
  28. # Config
  29. device = torch.device(f"cuda:{args.cuda}"
  30. if torch.cuda.is_available() and
  31. args.cuda >= 0 else "cpu")
  32. # model
  33. model = SimpleMLP(num_classes=10)
  34. # CL Benchmark Creation
  35. perm_mnist = PermutedMNIST(n_experiences=5)
  36. train_stream = perm_mnist.train_stream
  37. test_stream = perm_mnist.test_stream
  38. # Prepare for training & testing
  39. optimizer = SGD(model.parameters(), lr=0.001, momentum=0.9)
  40. criterion = CrossEntropyLoss()
  41. # Joint training strategy
  42. joint_train = JointTraining(
  43. model, optimizer, criterion, train_mb_size=32, train_epochs=1,
  44. eval_mb_size=32, device=device)
  45. # train and test loop
  46. results = []
  47. print("Starting training.")
  48. # Differently from other avalanche strategies, you NEED to call train
  49. # on the entire stream.
  50. joint_train.train(train_stream)
  51. results.append(joint_train.eval(test_stream))
  52. if __name__ == '__main__':
  53. parser = argparse.ArgumentParser()
  54. parser.add_argument('--cuda', type=int, default=0,
  55. help='Select zero-indexed cuda device. -1 to use CPU.')
  56. args = parser.parse_args()
  57. main(args)