123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081 |
- ################################################################################
- # Copyright (c) 2021 ContinualAI. #
- # Copyrights licensed under the MIT License. #
- # See the accompanying LICENSE file for terms. #
- # #
- # Date: 20-11-2020 #
- # Author(s): Vincenzo Lomonaco #
- # E-mail: contact@continualai.org #
- # Website: avalanche.continualai.org #
- ################################################################################
- """
- In this simple example we show all the different ways you can use MNIST with
- Avalanche.
- """
- from __future__ import absolute_import
- from __future__ import division
- from __future__ import print_function
- import torch
- import argparse
- from torch.nn import CrossEntropyLoss
- from torch.optim import SGD
- from avalanche.benchmarks.classic import PermutedMNIST, RotatedMNIST, \
- SplitMNIST
- from avalanche.models import SimpleMLP
- from avalanche.training.strategies import Naive
def main(args):
    """
    Train and evaluate a Naive strategy on one of the classic MNIST streams.

    :param args: parsed command-line namespace. ``args.cuda`` is the CUDA
        device index (a negative value selects the CPU) and
        ``args.mnist_type`` picks the benchmark: 'permuted', 'rotated', or
        anything else for split MNIST.
    """
    # Use the requested GPU only when CUDA is usable and the index is
    # non-negative; otherwise fall back to the CPU.
    if torch.cuda.is_available() and args.cuda >= 0:
        device = torch.device(f"cuda:{args.cuda}")
    else:
        device = torch.device("cpu")

    # Model
    model = SimpleMLP(num_classes=10)

    # Build whichever MNIST variation from the "classic" benchmarks was
    # requested; split MNIST is the fallback.
    variant = args.mnist_type
    if variant == 'permuted':
        benchmark = PermutedMNIST(n_experiences=5, seed=1)
    elif variant == 'rotated':
        angles = [30, 60, 90, 120, 150]
        benchmark = RotatedMNIST(
            n_experiences=5, rotations_list=angles, seed=1)
    else:
        benchmark = SplitMNIST(n_experiences=5, seed=1)

    # Then we can extract the parallel train and test streams
    train_stream = benchmark.train_stream
    test_stream = benchmark.test_stream

    # Optimizer and loss used for training & testing
    optimizer = SGD(model.parameters(), lr=0.001, momentum=0.9)
    criterion = CrossEntropyLoss()

    # Continual learning strategy with default logger
    cl_strategy = Naive(
        model,
        optimizer,
        criterion,
        train_mb_size=32,
        train_epochs=2,
        eval_mb_size=32,
        device=device,
    )

    # Train on each experience in turn, evaluating on the whole test stream
    # after every training step.
    results = []
    for experience in train_stream:
        print("Current Classes: ", experience.classes_in_this_experience)
        cl_strategy.train(experience)
        eval_outcome = cl_strategy.eval(test_stream)
        results.append(eval_outcome)
if __name__ == '__main__':
    # Command-line entry point: choose the MNIST variation and the device,
    # then hand the parsed options to main().
    cli = argparse.ArgumentParser()
    cli.add_argument(
        '--mnist_type',
        type=str,
        default='split',
        choices=['rotated', 'permuted', 'split'],
        help='Choose between MNIST variations: rotated, permuted or split.',
    )
    cli.add_argument(
        '--cuda',
        type=int,
        default=0,
        help='Select zero-indexed cuda device. -1 to use CPU.',
    )
    main(cli.parse_args())
|