import torch
import argparse
from avalanche.benchmarks import PermutedMNIST, SplitMNIST
from avalanche.training.strategies import GEM, AGEM
from avalanche.models import SimpleMLP
from avalanche.evaluation.metrics import forgetting_metrics, \
    accuracy_metrics, loss_metrics
from avalanche.logging import InteractiveLogger
from avalanche.training.plugins import EvaluationPlugin

"""
This example tests both GEM and A-GEM on Split MNIST and Permuted MNIST.
GEM is a streaming strategy: it uses only one training epoch per experience.
A-GEM may use a larger number of epochs.
Both GEM and A-GEM work with small mini-batches (usually of 10 patterns).

Warning 1: this implementation of GEM and A-GEM does not use task vectors.
Warning 2: GEM is much slower than A-GEM.

Results (learning rate is always 0.1):

GEM-PMNIST (5 experiences):
Hidden size 512, 1 training epoch, 512 patterns per experience, memory
strength 0.5. Average accuracy over all experiences at the end of training
on the last experience: 92.6%

GEM-SMNIST:
Patterns per experience: 256, memory strength: 0.5, hidden size: 256.
Average accuracy over all experiences at the end of training on the last
experience: 93.3%

AGEM-PMNIST (5 experiences):
Patterns per experience = sample size: 256, hidden size 256, 1 training
epoch. Average accuracy over all experiences at the end of training on the
last experience: 83.5%

AGEM-SMNIST:
Tried patterns per experience = sample size of 256, 512 and 1024, with
hidden size 256. Performance on previous tasks remains very poor in terms
of forgetting, and more training epochs do not change the result. With
1024 patterns per experience and sample size, 1 training epoch: average
accuracy over all experiences at the end of training on the last
experience: 67.0%
"""
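
# For intuition only (this helper is not used by the script): the core of
# A-GEM is a single gradient projection against a reference gradient
# computed on a batch sampled from the episodic memory. A minimal sketch,
# assuming flattened 1-D gradient tensors:
def _agem_project_sketch(g, g_ref):
    """Project gradient `g` so it does not increase loss on the memory batch."""
    dot = torch.dot(g, g_ref)
    if dot < 0:  # the proposed update conflicts with the memory gradient
        g = g - (dot / torch.dot(g_ref, g_ref)) * g_ref
    return g
# GEM instead solves a quadratic program with one such constraint per past
# experience, which is why it is much slower than A-GEM.
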

def main(args):
    model = SimpleMLP(hidden_size=args.hs)
    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr)
    criterion = torch.nn.CrossEntropyLoss()

    # check if the selected GPU is available, otherwise use the CPU
    assert args.cuda == -1 or args.cuda >= 0, "cuda must be -1 or >= 0."
    device = torch.device(f"cuda:{args.cuda}"
                          if torch.cuda.is_available() and args.cuda >= 0
                          else "cpu")
    print(f'Using device: {device}')

    # create the benchmark scenario
    if args.scenario == 'pmnist':
        scenario = PermutedMNIST(n_experiences=args.permutations)
    elif args.scenario == 'smnist':
        scenario = SplitMNIST(n_experiences=5, return_task_id=False)
    else:
        raise ValueError("Wrong scenario name. Allowed: pmnist, smnist.")

    # choose some metrics and an evaluation method
    interactive_logger = InteractiveLogger()
    eval_plugin = EvaluationPlugin(
        accuracy_metrics(minibatch=True, epoch=True,
                         experience=True, stream=True),
        loss_metrics(minibatch=True, epoch=True,
                     experience=True, stream=True),
        forgetting_metrics(experience=True),
        loggers=[interactive_logger])

    # create the strategy
    if args.strategy == 'gem':
        strategy = GEM(model, optimizer, criterion, args.patterns_per_exp,
                       args.memory_strength, train_epochs=args.epochs,
                       device=device, train_mb_size=10,
                       evaluator=eval_plugin)
    elif args.strategy == 'agem':
        strategy = AGEM(model, optimizer, criterion, args.patterns_per_exp,
                        args.sample_size, train_epochs=args.epochs,
                        device=device, train_mb_size=10,
                        evaluator=eval_plugin)
    else:
        raise ValueError("Wrong strategy name. Allowed: gem, agem.")

    # train on the selected scenario with the chosen strategy
    print('Starting experiment...')
    results = []
    for experience in scenario.train_stream:
        print("Start training on experience ", experience.current_experience)
        strategy.train(experience)
        print("End training on experience ", experience.current_experience)

        print('Computing accuracy on the test set')
        results.append(strategy.eval(scenario.test_stream[:]))
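    # `results` now holds one dictionary of metric values per evaluation
    # pass, should you want to inspect them after the loop.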


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--strategy', type=str, choices=['gem', 'agem'],
                        default='gem', help='Choose between GEM and A-GEM.')
    parser.add_argument('--scenario', type=str,
                        choices=['pmnist', 'smnist'], default='smnist',
                        help='Choose between Permuted MNIST and Split MNIST.')
    parser.add_argument('--patterns_per_exp', type=int, default=256,
                        help='Number of patterns to store in the memory '
                             'for each experience.')
    parser.add_argument('--sample_size', type=int, default=256,
                        help='Number of patterns to sample from memory '
                             'when projecting the gradient. A-GEM only.')
    parser.add_argument('--memory_strength', type=float, default=0.5,
                        help='Offset to add to the projection direction. '
                             'GEM only.')
    parser.add_argument('--lr', type=float, default=1e-1,
                        help='Learning rate.')
    parser.add_argument('--hs', type=int, default=256,
                        help='MLP hidden size.')
    parser.add_argument('--epochs', type=int, default=1,
                        help='Number of training epochs.')
    parser.add_argument('--permutations', type=int, default=5,
                        help='Number of experiences in Permuted MNIST.')
    parser.add_argument('--cuda', type=int, default=0,
                        help='Specify the GPU id to use. Use CPU if -1.')
    args = parser.parse_args()

    main(args)
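
# Example invocations (assuming this file is saved as gem_agem_mnist.py):
#   python gem_agem_mnist.py --strategy gem --scenario smnist
#   python gem_agem_mnist.py --strategy agem --scenario pmnist --epochs 3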