################################################################################
# Copyright (c) 2021 ContinualAI.                                              #
# Copyrights licensed under the MIT License.                                   #
# See the accompanying LICENSE file for terms.                                 #
#                                                                              #
# Date: 06-04-2021                                                             #
# Author(s): Tyler Hayes                                                       #
# E-mail: contact@continualai.org                                              #
# Website: avalanche.continualai.org                                           #
################################################################################
  11. """
  12. This is a simple example on how to use the Deep SLDA strategy.
  13. """
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import torch
import warnings

from torchvision import transforms

from avalanche.training.plugins import EvaluationPlugin
from avalanche.evaluation.metrics import accuracy_metrics, loss_metrics, \
    forgetting_metrics
from avalanche.logging import InteractiveLogger
from avalanche.benchmarks.classic import CORe50
from avalanche.training.strategies.deep_slda import StreamingLDA
from avalanche.models import SLDAResNetModel


def main(args):
    # Device config
    device = torch.device(f"cuda:{args.cuda}"
                          if torch.cuda.is_available() and
                          args.cuda >= 0 else "cpu")
    print('device ', device)
    # ---------

    # --- TRANSFORMATIONS
    _mu = [0.485, 0.456, 0.406]  # imagenet normalization
    _std = [0.229, 0.224, 0.225]
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_mu,
                             std=_std)
    ])
    # ---------

    # --- BENCHMARK CREATION
    benchmark = CORe50(scenario=args.scenario, train_transform=transform,
                       eval_transform=transform)
    # ---------
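    # CORe50 scenario names: 'ni' adds new instances of known classes, 'nc'
    # adds new classes, and 'nic'/'nicv2_*' mix both (the nicv2 suffix is the
    # number of training experiences).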

    eval_plugin = EvaluationPlugin(
        loss_metrics(epoch=True, experience=True, stream=True),
        accuracy_metrics(epoch=True, experience=True, stream=True),
        forgetting_metrics(experience=True, stream=True),
        loggers=[InteractiveLogger()]
    )
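    # forgetting_metrics reports how much accuracy on each experience has
    # dropped relative to when that experience was first learned.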

    criterion = torch.nn.CrossEntropyLoss()
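    # SLDA's update is closed-form (no backpropagation), so the criterion is
    # mainly used by the strategy to report loss values.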
    model = SLDAResNetModel(device=device, arch='resnet18',
                            imagenet_pretrained=args.imagenet_pretrained)
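    # The ResNet-18 backbone acts as a fixed feature extractor: only the SLDA
    # statistics computed on its pooled features are updated while streaming.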

    # CREATE THE STRATEGY INSTANCE
    cl_strategy = StreamingLDA(model, criterion,
                               args.feature_size, args.n_classes,
                               eval_mb_size=args.batch_size,
                               train_mb_size=args.batch_size,
                               train_epochs=1,
                               shrinkage_param=args.shrinkage,
                               streaming_update_sigma=args.plastic_cov,
                               device=device, evaluator=eval_plugin)
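    # shrinkage_param regularizes the covariance before it is inverted, and
    # streaming_update_sigma=True keeps updating the shared covariance during
    # streaming (a "plastic" covariance) instead of freezing it. Roughly,
    # following the Deep SLDA paper, class scores are computed as
    #     score_k(x) = w_k . x - 0.5 * w_k . mu_k,  with  w_k = Lambda mu_k,
    #     Lambda = [(1 - eps) * Sigma + eps * I]^{-1},
    # where mu_k are the running class means, Sigma the shared covariance and
    # eps the shrinkage parameter.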

    warnings.warn(
        "The Deep SLDA example is not perfectly aligned with "
        "the paper implementation since it does not use a base "
        "initialization phase and instead starts streaming from "
        "pre-trained weights.")

    # TRAINING LOOP
    print('Starting experiment...')
    for i, exp in enumerate(benchmark.train_stream):
        # fit SLDA model to batch (one sample at a time)
        cl_strategy.train(exp)

        # evaluate model on test data
        cl_strategy.eval(benchmark.test_stream)


if __name__ == '__main__':
    parser = argparse.ArgumentParser('SLDA Example with ResNet-18 on CORe50')

    parser.add_argument('--cuda', type=int, default=0,
                        help='Select zero-indexed cuda device. -1 to use CPU.')
    parser.add_argument('--n_classes', type=int, default=50)
    parser.add_argument('--scenario', type=str, default="nc",
                        choices=['ni', 'nc', 'nic', 'nicv2_79', 'nicv2_196',
                                 'nicv2_391'])

    # deep slda model parameters
    # initialize backbone with imagenet pre-trained weights
    parser.add_argument('--imagenet_pretrained', type=bool, default=True)
    # feature size before output layer (512 for resnet-18)
    parser.add_argument('--feature_size', type=int, default=512)
    # shrinkage value
    parser.add_argument('--shrinkage', type=float, default=1e-4)
    # plastic covariance matrix
    parser.add_argument('--plastic_cov', type=bool, default=True)
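    # Note: argparse's type=bool converts any non-empty string to True (e.g.
    # "--plastic_cov False" still parses as True), so these boolean options
    # are effectively controlled through their defaults.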
    parser.add_argument('--batch_size', type=int, default=512)
    args = parser.parse_args()

    main(args)
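# Example invocation (values shown are the defaults):
#   python deep_slda.py --scenario nc --batch_size 512 --shrinkage 1e-4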