getting_started.py 3.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. ################################################################################
  2. # Copyright (c) 2021 ContinualAI. #
  3. # Copyrights licensed under the MIT License. #
  4. # See the accompanying LICENSE file for terms. #
  5. # #
  6. # Date: 24-05-2020 #
  7. # Author(s): Lorenzo Pellegrini #
  8. # E-mail: contact@continualai.org #
  9. # Website: avalanche.continualai.org #
  10. ################################################################################
  11. """
  12. This is a simple example of how to use the new strategy API.
  13. """
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. from os.path import expanduser
  18. import argparse
  19. import torch
  20. from torch.nn import CrossEntropyLoss
  21. from torch.optim import SGD
  22. from torchvision import transforms
  23. from torchvision.datasets import MNIST
  24. from torchvision.transforms import ToTensor, RandomCrop
  25. from avalanche.benchmarks import nc_benchmark
  26. from avalanche.models import SimpleMLP
  27. from avalanche.training.strategies import Naive
  28. def main(args):
  29. # --- CONFIG
  30. device = torch.device(f"cuda:{args.cuda}"
  31. if torch.cuda.is_available() and
  32. args.cuda >= 0 else "cpu")
  33. # ---------
  34. # --- TRANSFORMATIONS
  35. train_transform = transforms.Compose([
  36. RandomCrop(28, padding=4),
  37. ToTensor(),
  38. transforms.Normalize((0.1307,), (0.3081,))
  39. ])
  40. test_transform = transforms.Compose([
  41. ToTensor(),
  42. transforms.Normalize((0.1307,), (0.3081,))
  43. ])
  44. # ---------
  45. # --- SCENARIO CREATION
  46. mnist_train = MNIST(root=expanduser("~") + "/.avalanche/data/mnist/",
  47. train=True, download=True, transform=train_transform)
  48. mnist_test = MNIST(root=expanduser("~") + "/.avalanche/data/mnist/",
  49. train=False, download=True, transform=test_transform)
  50. scenario = nc_benchmark(
  51. mnist_train, mnist_test, 5, task_labels=False, seed=1234)
  52. # ---------
  53. # MODEL CREATION
  54. model = SimpleMLP(num_classes=scenario.n_classes)
  55. # CREATE THE STRATEGY INSTANCE (NAIVE)
  56. cl_strategy = Naive(
  57. model, SGD(model.parameters(), lr=0.001, momentum=0.9),
  58. CrossEntropyLoss(), train_mb_size=100, train_epochs=4, eval_mb_size=100,
  59. device=device)
  60. # TRAINING LOOP
  61. print('Starting experiment...')
  62. results = []
  63. for experience in scenario.train_stream:
  64. print("Start of experience: ", experience.current_experience)
  65. print("Current Classes: ", experience.classes_in_this_experience)
  66. cl_strategy.train(experience)
  67. print('Training completed')
  68. print('Computing accuracy on the whole test set')
  69. results.append(cl_strategy.eval(scenario.test_stream))
  70. if __name__ == '__main__':
  71. parser = argparse.ArgumentParser()
  72. parser.add_argument('--cuda', type=int, default=0,
  73. help='Select zero-indexed cuda device. -1 to use CPU.')
  74. args = parser.parse_args()
  75. main(args)