1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677 |
- ################################################################################
- # Copyright (c) 2021 ContinualAI. #
- # Copyrights licensed under the MIT License. #
- # See the accompanying LICENSE file for terms. #
- # #
- # Date: 08-02-2021 #
- # Author(s): Lorenzo Pellegrini #
- # E-mail: contact@continualai.org #
- # Website: avalanche.continualai.org #
- ################################################################################
- """
- This is a simple example of how to use the AR1 strategy.
- """
- from __future__ import absolute_import
- from __future__ import division
- from __future__ import print_function
- import argparse
- import torch
- from torch.nn import CrossEntropyLoss
- from torchvision import transforms
- from torchvision.transforms import ToTensor, Resize
- from avalanche.benchmarks import SplitCIFAR10
- from avalanche.training.strategies.ar1 import AR1
def main(args):
    """Train and evaluate the AR1 strategy on Split CIFAR-10.

    CIFAR-10 is split into 5 experiences. The strategy is trained on each
    experience in turn, and after every experience it is evaluated on the
    whole test stream.

    Args:
        args: parsed command-line namespace. ``args.cuda`` is the
            zero-indexed CUDA device to use; a negative value forces the CPU.

    Returns:
        list: one entry per training experience, each being the value
        returned by ``cl_strategy.eval`` on the full test stream after
        training on that experience.
    """
    # Device config: use the requested GPU only when CUDA is available and
    # the index is non-negative; otherwise fall back to the CPU.
    device = torch.device(
        f"cuda:{args.cuda}"
        if torch.cuda.is_available() and args.cuda >= 0
        else "cpu"
    )
    # ---------

    # --- TRANSFORMATIONS
    # CIFAR-10 images are RGB, so normalization needs one mean/std per
    # channel; single-channel MNIST statistics (0.1307 / 0.3081) do not
    # describe this data. The ImageNet statistics below go with the 224px
    # resize, assuming AR1's default backbone is ImageNet-pretrained
    # (NOTE(review): confirm against the AR1 model definition).
    normalize = transforms.Normalize((0.485, 0.456, 0.406),
                                     (0.229, 0.224, 0.225))
    train_transform = transforms.Compose([
        Resize(224),
        ToTensor(),
        normalize,
    ])
    test_transform = transforms.Compose([
        Resize(224),
        ToTensor(),
        normalize,
    ])
    # ---------

    # --- SCENARIO CREATION
    scenario = SplitCIFAR10(5, train_transform=train_transform,
                            eval_transform=test_transform)
    # ---------

    # CREATE THE STRATEGY INSTANCE
    cl_strategy = AR1(criterion=CrossEntropyLoss(), device=device)

    # TRAINING LOOP
    print('Starting experiment...')
    results = []
    for experience in scenario.train_stream:
        print("Start of experience: ", experience.current_experience)
        print("Current Classes: ", experience.classes_in_this_experience)

        cl_strategy.train(experience, num_workers=0)
        print('Training completed')

        print('Computing accuracy on the whole test set')
        results.append(cl_strategy.eval(scenario.test_stream, num_workers=0))

    # Hand the per-experience evaluation results back to the caller instead
    # of silently discarding them.
    return results
if __name__ == '__main__':
    # Command-line entry point: the only option selects the compute device.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        '--cuda',
        type=int,
        default=0,
        help='Select zero-indexed cuda device. -1 to use CPU.',
    )
    args = cli.parse_args()
    main(args)
|