import argparse

import torch

from avalanche.benchmarks import SplitMNIST
from avalanche.evaluation.metrics import (accuracy_metrics,
                                          forgetting_metrics, loss_metrics)
from avalanche.logging import InteractiveLogger
from avalanche.models import SimpleMLP
from avalanche.training.plugins import EvaluationPlugin
from avalanche.training.strategies import LwF

- """
- This example tests Learning without Forgetting (LwF) on Split MNIST.
- The performance with default arguments should give an average accuracy
- of about 73%.
- """
def main(args):
    model = SimpleMLP(hidden_size=args.hs)
    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr)
    criterion = torch.nn.CrossEntropyLoss()

    # Check if the selected GPU is available, otherwise use the CPU.
    assert args.cuda == -1 or args.cuda >= 0, "cuda must be -1 or >= 0."
    device = torch.device(f"cuda:{args.cuda}"
                          if torch.cuda.is_available() and args.cuda >= 0
                          else "cpu")
    print(f'Using device: {device}')

    # Create the benchmark: MNIST split into 5 class-incremental experiences.
    scenario = SplitMNIST(n_experiences=5, return_task_id=False)

    interactive_logger = InteractiveLogger()
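    # The EvaluationPlugin gathers the requested metrics (accuracy and loss
    # at minibatch/epoch/experience/stream granularity, plus per-experience
    # forgetting) and emits them through the attached loggers.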
    eval_plugin = EvaluationPlugin(
        accuracy_metrics(
            minibatch=True, epoch=True, experience=True, stream=True),
        loss_metrics(minibatch=True, epoch=True, experience=True, stream=True),
        forgetting_metrics(experience=True),
        loggers=[interactive_logger])

    # Create the strategy.
    assert len(args.lwf_alpha) == 1 or len(args.lwf_alpha) == 5, \
        '--lwf_alpha must contain a single value or one value per experience.'
    lwf_alpha = args.lwf_alpha[0] if len(args.lwf_alpha) == 1 \
        else args.lwf_alpha
    strategy = LwF(model, optimizer, criterion, alpha=lwf_alpha,
                   temperature=args.softmax_temperature,
                   train_epochs=args.epochs, device=device,
                   train_mb_size=args.minibatch_size, evaluator=eval_plugin)
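    # Note: Avalanche's LwF accepts either a single alpha (shared by all
    # experiences) or a list with one alpha per experience, which is why the
    # single-element list is unwrapped above. A temperature above 1 softens
    # the distributions compared by the distillation loss.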

    # Train on the selected scenario with the chosen strategy.
    print('Starting experiment...')
    results = []
    for train_batch_info in scenario.train_stream:
        print("Start training on experience",
              train_batch_info.current_experience)
        strategy.train(train_batch_info, num_workers=0)
        print("End training on experience",
              train_batch_info.current_experience)

        print('Computing accuracy on the test set')
        results.append(strategy.eval(scenario.test_stream[:]))
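        # `strategy.eval` returns a dict mapping metric names to values, so
        # `results` collects one metrics snapshot per training experience.
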

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--lwf_alpha', nargs='+', type=float,
                        default=[0, 0.5, 1.333, 2.25, 3.2],
                        help='Penalty hyperparameter for LwF. It can be '
                             'either a list with multiple elements (one '
                             'alpha per experience) or a list with a single '
                             'element (same alpha for all experiences).')
    parser.add_argument('--softmax_temperature', type=float, default=1,
                        help='Temperature for the softmax used in '
                             'distillation.')
    parser.add_argument('--lr', type=float, default=1e-3,
                        help='Learning rate.')
    parser.add_argument('--hs', type=int, default=256,
                        help='MLP hidden size.')
    parser.add_argument('--epochs', type=int, default=10,
                        help='Number of training epochs.')
    parser.add_argument('--minibatch_size', type=int, default=128,
                        help='Minibatch size.')
    parser.add_argument('--cuda', type=int, default=0,
                        help='Specify the GPU id to use. Use CPU if -1.')
    args = parser.parse_args()

    main(args)
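
# Example invocation (the script name below is illustrative; substitute the
# actual file name):
#   python lwf_split_mnist.py --lwf_alpha 1.0 --epochs 2 --cuda -1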