wandb_logger.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134
  1. ################################################################################
  2. # Copyright (c) 2021 ContinualAI. #
  3. # Copyrights licensed under the MIT License. #
  4. # See the accompanying LICENSE file for terms. #
  5. # #
  6. # Date: 24-05-2020 #
  7. # Author(s): Andrea Cossu #
  8. # E-mail: contact@continualai.org #
  9. # Website: avalanche.continualai.org #
  10. ################################################################################
  11. """
  12. This is a simple example that shows how to use the
  13. WandB Logger
  14. """
  15. from __future__ import absolute_import
  16. from __future__ import division
  17. from __future__ import print_function
  18. from os.path import expanduser
  19. import argparse
  20. import torch
  21. from torch.nn import CrossEntropyLoss
  22. from torch.optim import SGD
  23. from torchvision import transforms
  24. from torchvision.datasets import MNIST
  25. from torchvision.transforms import ToTensor, RandomCrop
  26. from avalanche.benchmarks import nc_benchmark
  27. from avalanche.logging import InteractiveLogger, WandBLogger
  28. from avalanche.training.plugins import EvaluationPlugin
  29. from avalanche.evaluation.metrics import forgetting_metrics, \
  30. accuracy_metrics, loss_metrics, cpu_usage_metrics, \
  31. timing_metrics, gpu_usage_metrics, ram_usage_metrics, disk_usage_metrics, \
  32. MAC_metrics, confusion_matrix_metrics
  33. from avalanche.models import SimpleMLP
  34. from avalanche.training.strategies import Naive
  35. def main(args):
  36. # --- CONFIG
  37. device = torch.device(f"cuda:{args.cuda}"
  38. if torch.cuda.is_available() and
  39. args.cuda >= 0 else "cpu")
  40. # ---------
  41. # --- TRANSFORMATIONS
  42. train_transform = transforms.Compose([
  43. RandomCrop(28, padding=4),
  44. ToTensor(),
  45. transforms.Normalize((0.1307,), (0.3081,))
  46. ])
  47. test_transform = transforms.Compose([
  48. ToTensor(),
  49. transforms.Normalize((0.1307,), (0.3081,))
  50. ])
  51. # ---------
  52. # --- SCENARIO CREATION
  53. mnist_train = MNIST(root=expanduser("~") + "/.avalanche/data/mnist/",
  54. train=True,
  55. download=True, transform=train_transform)
  56. mnist_test = MNIST(root=expanduser("~") + "/.avalanche/data/mnist/",
  57. train=False,
  58. download=True, transform=test_transform)
  59. scenario = nc_benchmark(
  60. mnist_train, mnist_test, 5, task_labels=False, seed=1234)
  61. # ---------
  62. # MODEL CREATION
  63. model = SimpleMLP(num_classes=scenario.n_classes)
  64. interactive_logger = InteractiveLogger()
  65. wandb_logger = WandBLogger(project_name=args.project, run_name=args.run,
  66. config=args)
  67. eval_plugin = EvaluationPlugin(
  68. accuracy_metrics(
  69. minibatch=True, epoch=True, epoch_running=True,
  70. experience=True, stream=True),
  71. loss_metrics(
  72. minibatch=True, epoch=True, epoch_running=True,
  73. experience=True, stream=True),
  74. forgetting_metrics(experience=True, stream=True),
  75. confusion_matrix_metrics(stream=True, wandb=True,
  76. class_names=[str(i) for i in range(10)]),
  77. cpu_usage_metrics(
  78. minibatch=True, epoch=True, experience=True, stream=True),
  79. timing_metrics(
  80. minibatch=True, epoch=True, experience=True, stream=True),
  81. ram_usage_metrics(
  82. every=0.5, minibatch=True, epoch=True, experience=True,
  83. stream=True),
  84. gpu_usage_metrics(
  85. args.cuda, every=0.5, minibatch=True, epoch=True,
  86. experience=True, stream=True),
  87. disk_usage_metrics(
  88. minibatch=True, epoch=True, experience=True, stream=True),
  89. MAC_metrics(
  90. minibatch=True, epoch=True, experience=True),
  91. loggers=[interactive_logger, wandb_logger]
  92. )
  93. # CREATE THE STRATEGY INSTANCE (NAIVE)
  94. cl_strategy = Naive(
  95. model, SGD(model.parameters(), lr=0.001, momentum=0.9),
  96. CrossEntropyLoss(), train_mb_size=100, train_epochs=4, eval_mb_size=100,
  97. device=device, evaluator=eval_plugin)
  98. # TRAINING LOOP
  99. print('Starting experiment...')
  100. results = []
  101. for experience in scenario.train_stream:
  102. print("Start of experience: ", experience.current_experience)
  103. print("Current Classes: ", experience.classes_in_this_experience)
  104. cl_strategy.train(experience)
  105. print('Training completed')
  106. print('Computing accuracy on the whole test set')
  107. results.append(cl_strategy.eval(scenario.test_stream))
  108. if __name__ == '__main__':
  109. parser = argparse.ArgumentParser()
  110. parser.add_argument('--cuda', type=int, default=0,
  111. help='Select zero-indexed cuda device. -1 to use CPU.')
  112. parser.add_argument('--run', type=str, help='Provide a run name for WandB')
  113. parser.add_argument('--project', type=str,
  114. help='Define the name of the WandB project')
  115. args = parser.parse_args()
  116. main(args)