pytorchcv_models.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107
  1. ################################################################################
  2. # Copyright (c) 2021 ContinualAI. #
  3. # Copyrights licensed under the MIT License. #
  4. # See the accompanying LICENSE file for terms. #
  5. # #
  6. # Date: 12-10-2020 #
  7. # Author(s): Eli Verwimp #
  8. # E-mail: contact@continualai.org #
  9. # Website: avalanche.continualai.org #
  10. ################################################################################
  11. """
  12. This example shows how to train models provided by pytorchcv with the rehearsal
  13. strategy.
  14. """
  15. from __future__ import absolute_import
  16. from __future__ import division
  17. from __future__ import print_function
  18. from os.path import expanduser
  19. import argparse
  20. import torch
  21. from torch.nn import CrossEntropyLoss
  22. from torchvision import transforms
  23. from torchvision.datasets import CIFAR10
  24. from torchvision.transforms import ToTensor, RandomCrop
  25. import torch.optim.lr_scheduler
  26. from avalanche.benchmarks import nc_benchmark
  27. from avalanche.models import pytorchcv_wrapper
  28. from avalanche.training.strategies import Naive
  29. from avalanche.training.plugins import ReplayPlugin
  30. from avalanche.evaluation.metrics import forgetting_metrics, \
  31. accuracy_metrics, loss_metrics
  32. from avalanche.logging import InteractiveLogger
  33. from avalanche.training.plugins import EvaluationPlugin
  def main(args):
      """Train a pytorchcv ResNet-20 on Split CIFAR-10 with Naive + replay.

      Builds a class-incremental benchmark of 5 experiences from CIFAR-10
      (fixed class order 0..9, seed 1234). It trains the model on each
      experience in turn with a ReplayPlugin memory of 1000 samples. After
      every training experience it evaluates on the whole test stream.

      :param args: parsed command-line namespace. ``args.cuda`` selects the
          zero-indexed CUDA device; a negative value forces the CPU.
      :return: list holding the value returned by ``cl_strategy.eval`` after
          each training experience, in order.
      """
      # Model getter: specify dataset and depth of the network.
      model = pytorchcv_wrapper.resnet('cifar10', depth=20, pretrained=False)
      # Or get a more specific model. E.g. wide resnet, with depth 40 and
      # growth factor 8 for CIFAR-10:
      # model = pytorchcv_wrapper.get_model("wrn40_8_cifar10",
      #                                     pretrained=False)

      # --- CONFIG
      # Use the requested GPU only when CUDA is present and --cuda is
      # non-negative; otherwise fall back to the CPU. The device is not
      # overridden afterwards, so the --cuda flag is actually honored.
      use_cuda = torch.cuda.is_available() and args.cuda >= 0
      device = torch.device(f"cuda:{args.cuda}" if use_cuda else "cpu")

      # --- TRANSFORMATIONS
      # Per-channel normalization (presumably CIFAR-10 mean/std).
      transform = transforms.Compose([
          ToTensor(),
          transforms.Normalize((0.491, 0.482, 0.446), (0.247, 0.243, 0.261)),
      ])

      # --- SCENARIO CREATION
      # Same location as expanduser("~") + "/.avalanche/data/cifar10/".
      data_root = expanduser("~/.avalanche/data/cifar10/")
      cifar_train = CIFAR10(root=data_root, train=True, download=True,
                            transform=transform)
      cifar_test = CIFAR10(root=data_root, train=False, download=True,
                           transform=transform)
      # 5 experiences, no task labels, classes presented in order 0..9.
      scenario = nc_benchmark(
          cifar_train, cifar_test, 5, task_labels=False, seed=1234,
          fixed_class_order=list(range(10)))

      # choose some metrics and evaluation method
      interactive_logger = InteractiveLogger()
      eval_plugin = EvaluationPlugin(
          accuracy_metrics(
              minibatch=True, epoch=True, experience=True, stream=True),
          loss_metrics(minibatch=True, epoch=True, experience=True,
                       stream=True),
          forgetting_metrics(experience=True),
          loggers=[interactive_logger])

      # CREATE THE STRATEGY INSTANCE (Naive, with Replay)
      cl_strategy = Naive(
          model, torch.optim.SGD(model.parameters(), lr=0.01),
          CrossEntropyLoss(),
          train_mb_size=100, train_epochs=1, eval_mb_size=100,
          device=device,
          plugins=[ReplayPlugin(mem_size=1000)],
          evaluator=eval_plugin)

      # TRAINING LOOP
      print('Starting experiment...')
      results = []
      for experience in scenario.train_stream:
          print("Start of experience ", experience.current_experience)
          cl_strategy.train(experience)
          print('Training completed')

          # Evaluate on the full test stream after every experience so that
          # forgetting of earlier classes shows up in the metrics.
          print('Computing accuracy on the whole test set')
          results.append(cl_strategy.eval(scenario.test_stream))
      return results
  if __name__ == '__main__':
      # Command-line entry point: the only option is the CUDA device index,
      # where a negative value means "run on the CPU".
      cli_parser = argparse.ArgumentParser()
      cli_parser.add_argument(
          '--cuda',
          type=int,
          default=0,
          help='Select zero-indexed cuda device. -1 to use CPU.',
      )
      main(cli_parser.parse_args())