1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465 |
################################################################################
# Copyright (c) 2021 ContinualAI.                                              #
# Copyrights licensed under the MIT License.                                   #
# See the accompanying LICENSE file for terms.                                 #
#                                                                              #
# Date: 01-12-2020                                                             #
# Author(s): Andrea Cossu                                                      #
# E-mail: contact@continualai.org                                              #
# Website: avalanche.continualai.org                                           #
################################################################################
"""
This example trains on Split CIFAR10 with the Naive strategy.
In this example, each experience has a different task label.
The task label, although available, is not used at test time.
"""
- from __future__ import absolute_import
- from __future__ import division
- from __future__ import print_function
- import argparse
- import torch
- from torch.nn import CrossEntropyLoss
- from torch.optim import Adam
- from avalanche.benchmarks.classic import SplitCIFAR10
- from avalanche.models import SimpleMLP
- from avalanche.training.strategies import Naive
def main(args):
    """Train a Naive strategy on Split CIFAR10, evaluating after each experience.

    Args:
        args: parsed command-line namespace. ``args.cuda`` is the zero-indexed
            CUDA device to train on; a negative value forces the CPU.
    """
    # Pick the requested GPU only if CUDA is available AND a non-negative
    # index was supplied; in every other case fall back to the CPU.
    use_gpu = torch.cuda.is_available() and args.cuda >= 0
    device = torch.device(f"cuda:{args.cuda}" if use_gpu else "cpu")

    # input_size matches a flattened 3x32x32 CIFAR image; 10 output classes.
    model = SimpleMLP(input_size=32 * 32 * 3, num_classes=10)

    # Split CIFAR10 into 5 experiences, each exposing its own task id.
    benchmark = SplitCIFAR10(n_experiences=5, return_task_id=True)
    train_experiences = benchmark.train_stream
    test_experiences = benchmark.test_stream

    # Optimisation setup.
    optimizer = Adam(model.parameters(), lr=0.01)
    criterion = CrossEntropyLoss()

    # Plain fine-tuning baseline: no mechanism against forgetting.
    strategy = Naive(
        model=model,
        optimizer=optimizer,
        criterion=criterion,
        train_mb_size=128,
        train_epochs=3,
        eval_mb_size=128,
        device=device,
    )

    # Train on one experience at a time, then evaluate on the full test stream
    # so the effect of each new experience on earlier ones is visible.
    for experience in train_experiences:
        strategy.train(experience, num_workers=0)
        strategy.eval(test_experiences)
if __name__ == '__main__':
    # Command-line entry point: choose the CUDA device, then run the example.
    cli_parser = argparse.ArgumentParser()
    cli_parser.add_argument(
        '--cuda',
        type=int,
        default=0,
        help='Select zero-indexed cuda device. -1 to use CPU.',
    )
    main(cli_parser.parse_args())
|