123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110 |
- import torch
- from os.path import expanduser
- import argparse
- from torchvision.datasets import MNIST
- from torchvision.transforms import ToTensor
- from avalanche.benchmarks import PermutedMNIST, nc_benchmark
- from avalanche.training.strategies import EWC
- from avalanche.models import SimpleMLP
- from avalanche.evaluation.metrics import forgetting_metrics, \
- accuracy_metrics, loss_metrics, bwt_metrics
- from avalanche.logging import InteractiveLogger, TensorboardLogger
- from avalanche.training.plugins import EvaluationPlugin
- """
- This example tests EWC on Split MNIST and Permuted MNIST.
- It is possible to choose, among other options, between EWC with separate
- penalties and online EWC with a single penalty.
- On Permuted MNIST EWC maintains a very good performance on previous tasks
- with a wide range of configurations. The average accuracy on previous tasks
- at the end of training on all tasks is around 85%,
- with a comparable training accuracy.
- On Split MNIST, on the contrary, EWC fails to remember previous tasks and
- suffers complete forgetting in all configurations. The training accuracy
- is above 90%, but the average accuracy on previous tasks is around 20%.
- """
def main(args):
    """Train and evaluate an EWC strategy on Split MNIST or Permuted MNIST.

    Args:
        args: Parsed command-line namespace. The fields read are ``hs``,
            ``lr``, ``cuda``, ``scenario``, ``permutations``, ``ewc_lambda``,
            ``ewc_mode``, ``decay_factor``, ``epochs`` and ``minibatch_size``.

    Returns:
        A list with one entry per training experience. Each entry is the
        value returned by ``strategy.eval`` on the whole test stream right
        after training on that experience (presumably Avalanche's metric
        dictionary -- confirm against the installed Avalanche version).

    Raises:
        ValueError: If ``args.cuda`` is neither -1 nor a non-negative GPU id,
            or if ``args.scenario`` is not ``'pmnist'`` or ``'smnist'``.
    """
    # Explicit check instead of ``assert``: asserts are stripped when Python
    # runs with -O, which would silently skip this validation.
    if not (args.cuda == -1 or args.cuda >= 0):
        raise ValueError("cuda must be -1 or >= 0.")

    model = SimpleMLP(hidden_size=args.hs)
    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr)
    criterion = torch.nn.CrossEntropyLoss()

    # Use the requested GPU only when one was asked for and CUDA is present;
    # otherwise fall back to the CPU.
    use_gpu = args.cuda >= 0 and torch.cuda.is_available()
    device = torch.device(f"cuda:{args.cuda}" if use_gpu else "cpu")
    print(f'Using device: {device}')

    # create scenario
    if args.scenario == 'pmnist':
        scenario = PermutedMNIST(n_experiences=args.permutations)
    elif args.scenario == 'smnist':
        # Single download/cache location shared by the train and test splits.
        data_root = expanduser("~/.avalanche/data/mnist/")
        mnist_train = MNIST(root=data_root, train=True, download=True,
                            transform=ToTensor())
        mnist_test = MNIST(root=data_root, train=False, download=True,
                           transform=ToTensor())
        # 5 experiences without task labels; the fixed seed keeps the
        # split reproducible across runs.
        scenario = nc_benchmark(
            mnist_train, mnist_test, 5, task_labels=False, seed=1234)
    else:
        raise ValueError("Wrong scenario name. Allowed pmnist, smnist.")

    # choose some metrics and evaluation method
    interactive_logger = InteractiveLogger()
    tensorboard_logger = TensorboardLogger()
    eval_plugin = EvaluationPlugin(
        accuracy_metrics(
            minibatch=True, epoch=True, experience=True, stream=True),
        loss_metrics(
            minibatch=True, epoch=True, experience=True, stream=True),
        forgetting_metrics(experience=True, stream=True),
        bwt_metrics(experience=True, stream=True),
        loggers=[interactive_logger, tensorboard_logger])

    # create strategy
    strategy = EWC(model, optimizer, criterion, args.ewc_lambda,
                   args.ewc_mode, decay_factor=args.decay_factor,
                   train_epochs=args.epochs, device=device,
                   train_mb_size=args.minibatch_size, evaluator=eval_plugin)

    # Train on each experience in turn. After each one, evaluate on the
    # full test stream so that forgetting on earlier experiences is visible.
    print('Starting experiment...')
    results = []
    for experience in scenario.train_stream:
        print("Start training on experience ", experience.current_experience)
        strategy.train(experience)
        print("End training on experience", experience.current_experience)
        print('Computing accuracy on the test set')
        results.append(strategy.eval(scenario.test_stream[:]))
    return results
if __name__ == '__main__':
    # Command-line options as (flag, add_argument keyword options) pairs.
    # They are registered in exactly this order, so --help lists them the
    # same way.
    _CLI_OPTIONS = (
        ('--scenario', dict(type=str, choices=['pmnist', 'smnist'],
                            default='smnist',
                            help='Choose between Permuted MNIST, Split MNIST.')),
        ('--ewc_mode', dict(type=str, choices=['separate', 'online'],
                            default='separate',
                            help='Choose between EWC and online.')),
        ('--ewc_lambda', dict(type=float, default=0.4,
                              help='Penalty hyperparameter for EWC')),
        ('--decay_factor', dict(type=float, default=0.1,
                                help='Decay factor for importance '
                                     'when ewc_mode is online.')),
        ('--lr', dict(type=float, default=1e-3, help='Learning rate.')),
        ('--hs', dict(type=int, default=256, help='MLP hidden size.')),
        ('--epochs', dict(type=int, default=10,
                          help='Number of training epochs.')),
        ('--minibatch_size', dict(type=int, default=128,
                                  help='Minibatch size.')),
        ('--permutations', dict(type=int, default=5,
                                help='Number of experiences in Permuted MNIST.')),
        ('--cuda', dict(type=int, default=0,
                        help='Specify GPU id to use. Use CPU if -1.')),
    )

    parser = argparse.ArgumentParser()
    for flag, options in _CLI_OPTIONS:
        parser.add_argument(flag, **options)
    args = parser.parse_args()
    main(args)
|