replay.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107
  1. ################################################################################
  2. # Copyright (c) 2021 ContinualAI. #
  3. # Copyrights licensed under the MIT License. #
  4. # See the accompanying LICENSE file for terms. #
  5. # #
  6. # Date: 12-10-2020 #
  7. # Author(s): Vincenzo Lomonaco #
  8. # E-mail: contact@continualai.org #
  9. # Website: avalanche.continualai.org #
  10. ################################################################################
  11. """
This is a simple example of how to use replay (a Naive strategy with ReplayPlugin).
  13. """
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. from os.path import expanduser
  18. import argparse
  19. import torch
  20. from torch.nn import CrossEntropyLoss
  21. from torchvision import transforms
  22. from torchvision.datasets import MNIST
  23. from torchvision.transforms import ToTensor, RandomCrop
  24. import torch.optim.lr_scheduler
  25. from avalanche.benchmarks import nc_benchmark
  26. from avalanche.models import SimpleMLP
  27. from avalanche.training.strategies import Naive
  28. from avalanche.training.plugins import ReplayPlugin
  29. from avalanche.evaluation.metrics import forgetting_metrics, \
  30. accuracy_metrics, loss_metrics
  31. from avalanche.logging import InteractiveLogger
  32. from avalanche.training.plugins import EvaluationPlugin
  33. def main(args):
  34. # --- CONFIG
  35. device = torch.device(f"cuda:{args.cuda}"
  36. if torch.cuda.is_available() and
  37. args.cuda >= 0 else "cpu")
  38. n_batches = 5
  39. # ---------
  40. # --- TRANSFORMATIONS
  41. train_transform = transforms.Compose([
  42. RandomCrop(28, padding=4),
  43. ToTensor(),
  44. transforms.Normalize((0.1307,), (0.3081,))
  45. ])
  46. test_transform = transforms.Compose([
  47. ToTensor(),
  48. transforms.Normalize((0.1307,), (0.3081,))
  49. ])
  50. # ---------
  51. # --- SCENARIO CREATION
  52. mnist_train = MNIST(root=expanduser("~") + "/.avalanche/data/mnist/",
  53. train=True, download=True, transform=train_transform)
  54. mnist_test = MNIST(root=expanduser("~") + "/.avalanche/data/mnist/",
  55. train=False, download=True, transform=test_transform)
  56. scenario = nc_benchmark(
  57. mnist_train, mnist_test, n_batches, task_labels=False, seed=1234)
  58. # ---------
  59. # MODEL CREATION
  60. model = SimpleMLP(num_classes=scenario.n_classes)
  61. # choose some metrics and evaluation method
  62. interactive_logger = InteractiveLogger()
  63. eval_plugin = EvaluationPlugin(
  64. accuracy_metrics(
  65. minibatch=True, epoch=True, experience=True, stream=True),
  66. loss_metrics(minibatch=True, epoch=True, experience=True, stream=True),
  67. forgetting_metrics(experience=True),
  68. loggers=[interactive_logger])
  69. # CREATE THE STRATEGY INSTANCE (NAIVE)
  70. cl_strategy = Naive(model, torch.optim.Adam(model.parameters(), lr=0.001),
  71. CrossEntropyLoss(),
  72. train_mb_size=100, train_epochs=4, eval_mb_size=100,
  73. device=device,
  74. plugins=[ReplayPlugin(mem_size=10000)],
  75. evaluator=eval_plugin
  76. )
  77. # TRAINING LOOP
  78. print('Starting experiment...')
  79. results = []
  80. for experience in scenario.train_stream:
  81. print("Start of experience ", experience.current_experience)
  82. cl_strategy.train(experience)
  83. print('Training completed')
  84. print('Computing accuracy on the whole test set')
  85. results.append(cl_strategy.eval(scenario.test_stream))
  86. if __name__ == '__main__':
  87. parser = argparse.ArgumentParser()
  88. parser.add_argument('--cuda', type=int, default=0,
  89. help='Select zero-indexed cuda device. -1 to use CPU.')
  90. args = parser.parse_args()
  91. main(args)