task_metrics.py

################################################################################
# Copyright (c) 2021 ContinualAI.                                              #
# Copyrights licensed under the MIT License.                                   #
# See the accompanying LICENSE file for terms.                                 #
#                                                                              #
# Date: 24-05-2020                                                             #
# Author(s): Andrea Cossu                                                      #
# E-mail: contact@continualai.org                                              #
# Website: avalanche.continualai.org                                           #
################################################################################
  11. """
  12. This is a simple example on how to use the Evaluation Plugin with metrics
  13. returning values for different tasks.
  14. """
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse

import torch
from torch.nn import CrossEntropyLoss
from torch.optim import SGD

from avalanche.benchmarks.generators.benchmark_generators import \
    create_multi_dataset_generic_benchmark
from avalanche.benchmarks.utils import AvalancheTensorDataset
from avalanche.evaluation.metrics import forgetting_metrics, \
    accuracy_metrics, loss_metrics, cpu_usage_metrics, timing_metrics, \
    gpu_usage_metrics, ram_usage_metrics, disk_usage_metrics, MAC_metrics, \
    bwt_metrics
from avalanche.models import SimpleMLP
from avalanche.logging import InteractiveLogger, TextLogger, CSVLogger
from avalanche.training.plugins import EvaluationPlugin
from avalanche.training.strategies import Naive


def main(args):
    # --- CONFIG
    device = torch.device(f"cuda:{args.cuda}"
                          if torch.cuda.is_available() and args.cuda >= 0
                          else "cpu")
    # ---------

    # Three random tensor datasets per stream with per-pattern task labels:
    # 10 patterns of 3 features each, 3 classes, task labels drawn from [0, 5).
    tr_ds = [AvalancheTensorDataset(
        torch.randn(10, 3), torch.randint(0, 3, (10,)).tolist(),
        task_labels=torch.randint(0, 5, (10,)).tolist()) for _ in range(3)]
    ts_ds = [AvalancheTensorDataset(
        torch.randn(10, 3), torch.randint(0, 3, (10,)).tolist(),
        task_labels=torch.randint(0, 5, (10,)).tolist()) for _ in range(3)]
    scenario = create_multi_dataset_generic_benchmark(
        train_datasets=tr_ds, test_datasets=ts_ds)
    # ---------

    # MODEL CREATION
    model = SimpleMLP(num_classes=3, input_size=3)

    # DEFINE THE EVALUATION PLUGIN AND LOGGERS
    # The evaluation plugin manages the metrics computation.
    # It takes as argument a list of metrics and a list of loggers.
    # The evaluation plugin calls the loggers to serialize the metrics
    # and save them to persistent storage or print them to the standard output.

    # log to a text file
    text_logger = TextLogger(open('log.txt', 'a'))

    # print to stdout
    interactive_logger = InteractiveLogger()

    # log to CSV files
    csv_logger = CSVLogger()

    eval_plugin = EvaluationPlugin(
        accuracy_metrics(
            minibatch=True, epoch=True, epoch_running=True, experience=True,
            stream=True),
        loss_metrics(minibatch=True, epoch=True, epoch_running=True,
                     experience=True, stream=True),
        forgetting_metrics(experience=True, stream=True),
        bwt_metrics(experience=True, stream=True),
        cpu_usage_metrics(
            minibatch=True, epoch=True, epoch_running=True,
            experience=True, stream=True),
        timing_metrics(
            minibatch=True, epoch=True, epoch_running=True,
            experience=True, stream=True),
        ram_usage_metrics(
            every=0.5, minibatch=True, epoch=True,
            experience=True, stream=True),
        gpu_usage_metrics(
            args.cuda, every=0.5, minibatch=True, epoch=True,
            experience=True, stream=True),
        disk_usage_metrics(
            minibatch=True, epoch=True, experience=True, stream=True),
        MAC_metrics(
            minibatch=True, epoch=True, experience=True),
        loggers=[interactive_logger, text_logger, csv_logger],
        collect_all=True)  # collect all metrics (set to True by default)

    # CREATE THE STRATEGY INSTANCE (NAIVE)
    cl_strategy = Naive(
        model, SGD(model.parameters(), lr=0.001, momentum=0.9),
        CrossEntropyLoss(), train_mb_size=500, train_epochs=1,
        eval_mb_size=100, device=device, evaluator=eval_plugin, eval_every=1)
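    # With eval_every=1 the strategy also runs a periodic evaluation during
    # training (here, after each training epoch) on the streams passed to
    # train() via eval_streams.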

    # TRAINING LOOP
    print('Starting experiment...')
    results = []
    for experience in scenario.train_stream:
        print("Start of experience: ", experience.current_experience)
        print("Current Classes: ", experience.classes_in_this_experience)

        # train returns a dictionary containing the last recorded value
        # for each metric.
        res = cl_strategy.train(experience,
                                eval_streams=[scenario.test_stream])
        print('Training completed')

        print('Computing accuracy on the whole test set')
        # eval returns a dictionary with the last metric value collected
        # during the evaluation on that stream
        results.append(cl_strategy.eval(scenario.test_stream))

    print(f"Test metrics:\n{results}")

    # Dict with all the metric curves, only available when
    # `collect_all` is True. Each entry is a (x, metric value) tuple.
    # You can use this dictionary to manipulate the metrics
    # without avalanche.
    all_metrics = cl_strategy.evaluator.get_all_metrics()
    print(f"Stored metrics: {list(all_metrics.keys())}")


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--cuda', type=int, default=0,
                        help='Select zero-indexed cuda device. -1 to use CPU.')
    args = parser.parse_args()
    main(args)