confusion_matrix.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107
  1. ################################################################################
  2. # Copyright (c) 2021 ContinualAI. #
  3. # Copyrights licensed under the MIT License. #
  4. # See the accompanying LICENSE file for terms. #
  5. # #
  6. # Date: 24-05-2020 #
  7. # Author(s): Andrea Cossu #
  8. # E-mail: contact@continualai.org #
  9. # Website: avalanche.continualai.org #
  10. ################################################################################
  11. """
  12. This example shows how to produce confusion matrices during training and
  13. evaluation.
  14. """
  15. from __future__ import absolute_import
  16. from __future__ import division
  17. from __future__ import print_function
  18. from os.path import expanduser
  19. import argparse
  20. import torch
  21. from torch.nn import CrossEntropyLoss
  22. from torch.optim import SGD
  23. from torchvision import transforms
  24. from torchvision.datasets import MNIST
  25. from torchvision.transforms import ToTensor, RandomCrop
  26. from avalanche.benchmarks import nc_benchmark
  27. from avalanche.models import SimpleMLP
  28. from avalanche.training.strategies import Naive
  29. from avalanche.training.plugins import EvaluationPlugin, ReplayPlugin
  30. from avalanche.evaluation.metrics import confusion_matrix_metrics, \
  31. accuracy_metrics, loss_metrics
  32. from avalanche.logging import InteractiveLogger
  33. def main(args):
  34. # --- CONFIG
  35. device = torch.device(f"cuda:{args.cuda}"
  36. if torch.cuda.is_available() and
  37. args.cuda >= 0 else "cpu")
  38. # ---------
  39. # --- TRANSFORMATIONS
  40. train_transform = transforms.Compose([
  41. RandomCrop(28, padding=4),
  42. ToTensor(),
  43. transforms.Normalize((0.1307,), (0.3081,))
  44. ])
  45. test_transform = transforms.Compose([
  46. ToTensor(),
  47. transforms.Normalize((0.1307,), (0.3081,))
  48. ])
  49. # ---------
  50. # --- SCENARIO CREATION
  51. mnist_train = MNIST(root=expanduser("~") + "/.avalanche/data/mnist/",
  52. train=True, download=True, transform=train_transform)
  53. mnist_test = MNIST(root=expanduser("~") + "/.avalanche/data/mnist/",
  54. train=False, download=True, transform=test_transform)
  55. scenario = nc_benchmark(
  56. mnist_train, mnist_test, 5, task_labels=False, seed=1234)
  57. # ---------
  58. # MODEL CREATION
  59. model = SimpleMLP(num_classes=scenario.n_classes)
  60. eval_plugin = EvaluationPlugin(
  61. accuracy_metrics(epoch=True, experience=True, stream=True),
  62. loss_metrics(epoch=True, experience=True, stream=True),
  63. # save image should be False to appropriately view
  64. # results in Interactive Logger.
  65. # a tensor will be printed
  66. confusion_matrix_metrics(save_image=False, normalize='all',
  67. stream=True),
  68. loggers=InteractiveLogger()
  69. )
  70. # CREATE THE STRATEGY INSTANCE (NAIVE)
  71. cl_strategy = Naive(
  72. model, SGD(model.parameters(), lr=0.001, momentum=0.9),
  73. CrossEntropyLoss(), train_mb_size=100, train_epochs=4, eval_mb_size=100,
  74. device=device, evaluator=eval_plugin, plugins=[ReplayPlugin(5000)])
  75. # TRAINING LOOP
  76. print('Starting experiment...')
  77. results = []
  78. for experience in scenario.train_stream:
  79. print("Start of experience: ", experience.current_experience)
  80. print("Current Classes: ", experience.classes_in_this_experience)
  81. cl_strategy.train(experience)
  82. print('Training completed')
  83. print('Computing accuracy on the whole test set')
  84. results.append(cl_strategy.eval(scenario.test_stream))
  85. if __name__ == '__main__':
  86. parser = argparse.ArgumentParser()
  87. parser.add_argument('--cuda', type=int, default=0,
  88. help='Select zero-indexed cuda device. -1 to use CPU.')
  89. args = parser.parse_args()
  90. main(args)