# ar1.py (2.7 KB)
################################################################################
# Copyright (c) 2021 ContinualAI.                                              #
# Copyrights licensed under the MIT License.                                   #
# See the accompanying LICENSE file for terms.                                 #
#                                                                              #
# Date: 08-02-2021                                                             #
# Author(s): Lorenzo Pellegrini                                                #
# E-mail: contact@continualai.org                                              #
# Website: avalanche.continualai.org                                           #
################################################################################
"""
This is a simple example of how to use the AR1 strategy.
"""
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import argparse
  18. import torch
  19. from torch.nn import CrossEntropyLoss
  20. from torchvision import transforms
  21. from torchvision.transforms import ToTensor, Resize
  22. from avalanche.benchmarks import SplitCIFAR10
  23. from avalanche.training.strategies.ar1 import AR1
  24. def main(args):
  25. # Device config
  26. device = torch.device(f"cuda:{args.cuda}"
  27. if torch.cuda.is_available() and
  28. args.cuda >= 0 else "cpu")
  29. # ---------
  30. # --- TRANSFORMATIONS
  31. train_transform = transforms.Compose([
  32. Resize(224),
  33. ToTensor(),
  34. transforms.Normalize((0.1307,), (0.3081,))
  35. ])
  36. test_transform = transforms.Compose([
  37. Resize(224),
  38. ToTensor(),
  39. transforms.Normalize((0.1307,), (0.3081,))
  40. ])
  41. # ---------
  42. # --- SCENARIO CREATION
  43. scenario = SplitCIFAR10(5, train_transform=train_transform,
  44. eval_transform=test_transform)
  45. # ---------
  46. # CREATE THE STRATEGY INSTANCE
  47. cl_strategy = AR1(criterion=CrossEntropyLoss(), device=device)
  48. # TRAINING LOOP
  49. print('Starting experiment...')
  50. results = []
  51. for experience in scenario.train_stream:
  52. print("Start of experience: ", experience.current_experience)
  53. print("Current Classes: ", experience.classes_in_this_experience)
  54. cl_strategy.train(experience, num_workers=0)
  55. print('Training completed')
  56. print('Computing accuracy on the whole test set')
  57. results.append(cl_strategy.eval(scenario.test_stream, num_workers=0))
  58. if __name__ == '__main__':
  59. parser = argparse.ArgumentParser()
  60. parser.add_argument('--cuda', type=int, default=0,
  61. help='Select zero-indexed cuda device. -1 to use CPU.')
  62. args = parser.parse_args()
  63. main(args)