################################################################################
# Copyright (c) 2021 ContinualAI.                                              #
# Copyrights licensed under the MIT License.                                   #
# See the accompanying LICENSE file for terms.                                 #
#                                                                              #
# Date: 24-05-2020                                                             #
# Author(s): Lorenzo Pellegrini                                                #
# E-mail: contact@continualai.org                                              #
# Website: avalanche.continualai.org                                           #
################################################################################
  11. """
  12. This is a simple example on how to use the Evaluation Plugin.
  13. """
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. from os.path import expanduser
  18. import argparse
  19. import torch
  20. from torch.nn import CrossEntropyLoss
  21. from torch.optim import SGD
  22. from torchvision import transforms
  23. from torchvision.datasets import MNIST
  24. from torchvision.transforms import ToTensor, RandomCrop
  25. from avalanche.benchmarks import nc_benchmark
  26. from avalanche.evaluation.metrics import forgetting_metrics, \
  27. accuracy_metrics, loss_metrics, cpu_usage_metrics, timing_metrics, \
  28. gpu_usage_metrics, ram_usage_metrics, disk_usage_metrics, MAC_metrics, \
  29. bwt_metrics, forward_transfer_metrics
  30. from avalanche.models import SimpleMLP
  31. from avalanche.logging import InteractiveLogger, TextLogger, CSVLogger
  32. from avalanche.training.plugins import EvaluationPlugin
  33. from avalanche.training.strategies import Naive
  34. def main(args):
    # --- CONFIG
    device = torch.device(f"cuda:{args.cuda}" if torch.cuda.is_available() and
                          args.cuda >= 0 else "cpu")
    # ---------

    # --- TRANSFORMATIONS
    train_transform = transforms.Compose([
        RandomCrop(28, padding=4),
        ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    test_transform = transforms.Compose([
        ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    # ---------

    # --- SCENARIO CREATION
    mnist_train = MNIST(root=expanduser("~") + "/.avalanche/data/mnist/",
                        train=True, download=True, transform=train_transform)
    mnist_test = MNIST(root=expanduser("~") + "/.avalanche/data/mnist/",
                       train=False, download=True, transform=test_transform)
    scenario = nc_benchmark(
        mnist_train, mnist_test, 5, task_labels=False, seed=1234)
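    # Since MNIST has 10 classes, nc_benchmark splits them into 5 experiences
    # of 2 new classes each; task_labels=False yields a class-incremental
    # stream with no task identifiers.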
    # ---------

    # MODEL CREATION
    model = SimpleMLP(num_classes=scenario.n_classes)

    # DEFINE THE EVALUATION PLUGIN AND LOGGER
    # The evaluation plugin manages the metrics computation.
    # It takes as arguments a list of metrics and a list of loggers.
    # The evaluation plugin calls the loggers to serialize the metrics and
    # save them to persistent storage or print them to the standard output.

    # log to a text file
    text_logger = TextLogger(open('log.txt', 'a'))

    # print to stdout
    interactive_logger = InteractiveLogger()

    csv_logger = CSVLogger()
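    # write metric values to CSV files in the logger's log folder
    # (the default folder and file layout depend on the Avalanche version)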
    eval_plugin = EvaluationPlugin(
        accuracy_metrics(
            minibatch=True, epoch=True, epoch_running=True, experience=True,
            stream=True),
        loss_metrics(minibatch=True, epoch=True, epoch_running=True,
                     experience=True, stream=True),
        forgetting_metrics(experience=True, stream=True),
        bwt_metrics(experience=True, stream=True),
        forward_transfer_metrics(experience=True, stream=True),
        cpu_usage_metrics(
            minibatch=True, epoch=True, epoch_running=True,
            experience=True, stream=True),
        timing_metrics(
            minibatch=True, epoch=True, epoch_running=True,
            experience=True, stream=True),
        ram_usage_metrics(
            every=0.5, minibatch=True, epoch=True,
            experience=True, stream=True),
        gpu_usage_metrics(
            args.cuda, every=0.5, minibatch=True, epoch=True,
            experience=True, stream=True),
        disk_usage_metrics(
            minibatch=True, epoch=True, experience=True, stream=True),
        MAC_metrics(
            minibatch=True, epoch=True, experience=True),
        loggers=[interactive_logger, text_logger, csv_logger],
        collect_all=True)  # collect all metrics (set to True by default)
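    # Each *_metrics helper above returns a list of metric objects; the
    # boolean flags select the granularity at which values are emitted:
    # after every training minibatch, at the end of every training epoch,
    # as a running average while the epoch progresses (epoch_running),
    # after every evaluation experience, and over the whole evaluation stream.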
    # CREATE THE STRATEGY INSTANCE (NAIVE)
    cl_strategy = Naive(
        model, SGD(model.parameters(), lr=0.001, momentum=0.9),
        CrossEntropyLoss(), train_mb_size=500, train_epochs=1,
        eval_mb_size=100, device=device, evaluator=eval_plugin, eval_every=1)
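    # Note: with eval_every=1 and eval_streams passed to train() below, the
    # strategy also evaluates periodically during training (every epoch);
    # eval_every=-1 would disable this periodic evaluation.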
    # TRAINING LOOP
    print('Starting experiment...')
    results = []
    for experience in scenario.train_stream:
        print("Start of experience: ", experience.current_experience)
        print("Current Classes: ", experience.classes_in_this_experience)

        # train returns a dictionary containing the last recorded value
        # for each metric.
        res = cl_strategy.train(experience,
                                eval_streams=[scenario.test_stream])
        print('Training completed')

        print('Computing accuracy on the whole test set')
        # eval returns a dictionary with the last metric value collected
        # during the evaluation of that stream
        results.append(cl_strategy.eval(scenario.test_stream))

    print(f"Test metrics:\n{results}")

    # Dict with all the metric curves, only available when `collect_all`
    # is True. Each entry maps a metric name to a tuple of
    # (x values, metric values). You can use this dictionary to manipulate
    # the metrics without Avalanche.
    all_metrics = cl_strategy.evaluator.get_all_metrics()
    print(f"Stored metrics: {list(all_metrics.keys())}")
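    # Minimal illustration: the curves can be consumed without any Avalanche
    # code. The 'Top1_Acc_Stream' substring below is only an assumption about
    # the metric naming and may differ across Avalanche versions.
    for name, (steps, values) in all_metrics.items():
        if 'Top1_Acc_Stream' in name:
            print(name, list(zip(steps, values)))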


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--cuda', type=int, default=0,
                        help='Select zero-indexed cuda device. -1 to use CPU.')
    args = parser.parse_args()
    main(args)