synaptic_intelligence.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107
  1. ################################################################################
  2. # Copyright (c) 2021 ContinualAI. #
  3. # Copyrights licensed under the MIT License. #
  4. # See the accompanying LICENSE file for terms. #
  5. # #
  6. # Date: 26-01-2021 #
  7. # Author(s): Lorenzo Pellegrini #
  8. # E-mail: contact@continualai.org #
  9. # Website: avalanche.continualai.org #
  10. ################################################################################
  11. """
  12. This is a simple example of how to use the Synaptic Intelligence Plugin.
  13. """
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import argparse
  18. import torch
  19. from torch.nn import CrossEntropyLoss
  20. from torch.optim import Adam
  21. from torchvision import transforms
  22. from torchvision.transforms import ToTensor, Resize
  23. from avalanche.benchmarks import SplitCIFAR10
  24. from avalanche.evaluation.metrics import forgetting_metrics, accuracy_metrics, \
  25. loss_metrics
  26. from avalanche.logging import InteractiveLogger
  27. from avalanche.logging.tensorboard_logger import TensorboardLogger
  28. from avalanche.models.mobilenetv1 import MobilenetV1
  29. from avalanche.training.plugins import EvaluationPlugin
  30. from avalanche.training.strategies.strategy_wrappers import SynapticIntelligence
  31. from avalanche.training.utils import adapt_classification_layer
  32. def main(args):
  33. # --- CONFIG
  34. device = torch.device(f"cuda:{args.cuda}"
  35. if torch.cuda.is_available() and
  36. args.cuda >= 0 else "cpu")
  37. # ---------
  38. # --- TRANSFORMATIONS
  39. train_transform = transforms.Compose([
  40. Resize(224),
  41. ToTensor(),
  42. transforms.Normalize((0.1307,), (0.3081,))
  43. ])
  44. test_transform = transforms.Compose([
  45. Resize(224),
  46. ToTensor(),
  47. transforms.Normalize((0.1307,), (0.3081,))
  48. ])
  49. # ---------
  50. # --- SCENARIO CREATION
  51. scenario = SplitCIFAR10(5, train_transform=train_transform,
  52. eval_transform=test_transform)
  53. # ---------
  54. # MODEL CREATION
  55. model = MobilenetV1()
  56. adapt_classification_layer(model, scenario.n_classes, bias=False)
  57. # DEFINE THE EVALUATION PLUGIN AND LOGGER
  58. my_logger = TensorboardLogger(
  59. tb_log_dir="logs", tb_log_exp_name="logging_example")
  60. # print to stdout
  61. interactive_logger = InteractiveLogger()
  62. evaluation_plugin = EvaluationPlugin(
  63. accuracy_metrics(
  64. minibatch=True, epoch=True, experience=True, stream=True),
  65. loss_metrics(minibatch=True, epoch=True, experience=True, stream=True),
  66. forgetting_metrics(experience=True),
  67. loggers=[my_logger, interactive_logger])
  68. # CREATE THE STRATEGY INSTANCE (NAIVE with the Synaptic Intelligence plugin)
  69. cl_strategy = SynapticIntelligence(
  70. model, Adam(model.parameters(), lr=0.001), CrossEntropyLoss(),
  71. si_lambda=0.0001, train_mb_size=128, train_epochs=4, eval_mb_size=128,
  72. device=device, evaluator=evaluation_plugin)
  73. # TRAINING LOOP
  74. print('Starting experiment...')
  75. results = []
  76. for experience in scenario.train_stream:
  77. print("Start of experience: ", experience.current_experience)
  78. print("Current Classes: ", experience.classes_in_this_experience)
  79. cl_strategy.train(experience)
  80. print('Training completed')
  81. print('Computing accuracy on the whole test set')
  82. results.append(cl_strategy.eval(scenario.test_stream))
  83. if __name__ == '__main__':
  84. parser = argparse.ArgumentParser()
  85. parser.add_argument('--cuda', type=int, default=0,
  86. help='Select zero-indexed cuda device. -1 to use CPU.')
  87. args = parser.parse_args()
  88. main(args)