import pandas as pd
from skimage import io
import numpy as np
from PIL import Image
import torch
import torch.nn as nn  # all neural network modules: nn.Linear, nn.Conv2d, BatchNorm, loss functions
import torch.optim as optim  # optimization algorithms: SGD, Adam, etc.
import torchvision.transforms as transforms  # transformations we can perform on our dataset
from torchvision.io import read_image
import os
from torch.utils.data import (
    Dataset,
    DataLoader,
)
from typing import Optional, Sequence, Union, List
from avalanche.benchmarks.generators import nc_benchmark, ni_benchmark
from avalanche.models import SimpleCNN
from avalanche.models import pytorchcv_wrapper
from avalanche.benchmarks.utils.datasets_from_filelists import SeqPathsDataset, PathsDataset
from avalanche.training.strategies import Naive, Cumulative, JointTraining
#os.environ['CUDA_LAUNCH_BLOCKING'] = "0"
from avalanche.benchmarks.generators import paths_benchmark
#from avalanche.training.plugins.sequence_data import _get_random_indicies, _get_cls_acc_based_indicies
from avalanche.training.plugins import (
    SeqDataPlugin,
    ReplayPlugin,
    EvaluationPlugin,
    MaximallyInterferedRetrievalRehersalPlugin,
    ClassErrorRehersalPlugin,
    ClassErrorRehersalTemperaturePlugin,
    ClassFrequencyRehearsalPlugin,
    RandomRehersal,
    ClassBalancingReservoirMemoryRehersalPlugin,
    ReservoirMemoryRehearsalPlugin,
    WeightedeMovingClassErrorAverageRehersalPlugin,
    ClassErrorFrequencyAvgRehearsalPlugin,
    FillExpBasedRehearsalPlugin,
    FillExpOversampleBasedRehearsalPlugin,
)
from avalanche.evaluation.metrics import (
    ClassTop_nAvgAcc,
    SeasonClassTop_nAcc,
    Top1AccTransfer,
    ClasswiseTop_nAcc,
    Top_nAcc,
    SeasonTop_nAcc,
    SeqMaxAcc,
    SeqClassTop_nAvgAcc,
    SeqClassAvgMaxAcc,
    SeasonSeqClassTop_nAcc,
    SeqClasswiseTop_nAcc,
    SeqAnyAcc,
    SeasonSeqTop_nAcc,
    SeqTop_nAcc,
    BinaryAnimalAcc,
    QuatrupleAnimalAcc,
    SeqBinaryAnimalAcc,
    IgnoreNonAnimalOutputsTop_nAcc,
    SeqIgnoreNonAnimalOutputsAcc,
    ClassIgnoreNonAnimalOutputsTop_nAvgAcc,
)
from avalanche.logging import StrategyLogger, InteractiveLogger, TextLogger, TensorboardLogger, CSVLogger, GenericCSVLogger
from avalanche.evaluation.metrics import accuracy_metrics, loss_metrics
import matplotlib.pyplot as plt
import pickle as pkl
import time
import argparse
import threading
from avalanche.models.utils import LabelSmoothingCrossEntropy
from pathlib import Path
import datetime
import sys
import random
parser = argparse.ArgumentParser()
parser.add_argument('--file', '-f', type=str, help='file name prefix from which train/test/val_stream.pkl are loaded')
parser.add_argument('--data_root', '-roots', type=str, default='/home/boehlke/AMMOD/continual_learning/data/boehlke_ssd_Species/', help='directory from which the train/test stream .pkl files are loaded')
parser.add_argument('--log_root', '-log', type=str, default='/home/boehlke/AMMOD/continual_learning/results_ever_eval_test/', help='directory where results shall be logged')
parser.add_argument('--strategy', '-s', type=str, default='naive', help='strategy name')
parser.add_argument('--epochs', '-e', type=int, default=3, help='number of times the finetuning set (rehearsal + experience) is iterated over')
parser.add_argument('--batch_size', '-bs', type=int, default=48)
parser.add_argument('--eval_after_n_exp', '-ne', type=int, default=1, help='interval of evaluation on the validation set, in experiences; for joint training, the interval is in epochs')
parser.add_argument('--test_eval_after_n_exp', '-tne', type=int, default=100, help='interval of evaluation on the test set, in experiences; for joint training, the interval is in epochs')
parser.add_argument('--exp_size', '-es', type=int, default=128, help='experience size, also used as the eval batch size')
parser.add_argument('--memory_filling', '-memf', type=str, default='inf', help='how the memory shall be filled; default is inf, other options: cbrs, stdrs (see get_strategy())')
parser.add_argument('--memory_size', '-mems', type=float, default=0.3, help='either the ratio of the entire train stream saved in memory or, if larger than 1, the number of instances that fit in memory')
parser.add_argument('--temperature', '-temp', type=float, default=None, help='sharpening/softening temperature value for class-error-based rehearsal')
parser.add_argument('--rehearsal_method', '-rm', type=str, default=None, help='abbreviation for one of the implemented rehearsal methods: "ce, cf, rr, mir, wce, cefa, fe, feo", see get_strategy()')
parser.add_argument('--nr_of_steps_to_avg', '-nexp_avg', type=int, default=None, help='must be set for WeightedeMovingClassErrorAverageRehersalPlugin; defines how many past experiences are used in the weighted average')
parser.add_argument('--sigma', '-sig', type=float, default=None, help='must be set for WeightedeMovingClassErrorAverageRehersalPlugin; sets the rate of decay for weights in the moving weighted average. The bigger sigma, the more past error rates are weighted')
parser.add_argument('--buffer_data_ratio', '-dr', type=float, default=8, help='how many images from memory should be used in the rehearsal set for each image in the experience')
parser.add_argument('--shuffle_stream', '-ss', action='store_true', default=False, help='whether the train stream should be shuffled on an image level to obtain a more iid distribution for comparison')
parser.add_argument('--num_classes', '-nc', type=int, default=16, help='number of classes in the dataset')
parser.add_argument('--non_animal_cls', '-nac', default=[13, 14, 15], nargs='+', type=int, help='classes that do not contain species, which is relevant for the "binary" accuracies')
parser.add_argument('--images_root', '-ir', type=str, default='/home/AMMOD_data/camera_traps/BayerWald/G-Fallen/original/', help='directory where the image files are stored')
parser.add_argument('--label_dict', '-ld', type=str, default="/home/boehlke/AMMOD/cam_trap_classification/data/csv_files/BIRDS.pkl", help='used to create a confusion matrix at regular intervals')
args = parser.parse_args()
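# Illustrative invocation; the script name and the stream prefix are placeholders,
# the flags are the ones defined above:
#   python continual_training.py -f <stream_prefix> -s naive -rm ce -dr 8 -memf cbrs -mems 0.3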
- print("PyTorch Version: ",torch.__version__)
- use_cuda = torch.cuda.is_available()
args.detected_device = torch.device("cuda:0" if use_cuda else "cpu")
print(args.detected_device)
number_workers = 8
# Train and test data
# Even when training with a stream of experience size 384, the 128 version needs to
# be loaded first to match the winter stream files.
with open(args.data_root + args.file.replace('384', '128') + '_train_stream.pkl', 'rb') as handle:
    args.train_stream = pkl.load(handle)
with open(args.data_root + args.file.replace('384', '128') + '_test_stream.pkl', 'rb') as handle:
    args.test_stream = pkl.load(handle)  #[:5000]
with open(args.data_root + args.file.replace('384', '128') + '_val_stream.pkl', 'rb') as handle:
    args.val_stream = pkl.load(handle)  #[:5000]
with open(args.data_root + args.file.replace('384', '128') + '_winter_val_stream.pkl', 'rb') as handle:
    args.validation_stream_winter = pkl.load(handle)
with open(args.data_root + args.file.replace('384', '128') + '_winter_test_stream.pkl', 'rb') as handle:
    args.test_stream_winter = pkl.load(handle)
with open(args.data_root + args.file.replace('_crop', '').replace('384', '128') + '_exp_season_split_dict.pkl', 'rb') as handle:
    args.exp_season_split_dict = pkl.load(handle)
# Backbone model settings
model_depth = 18
args.model = pytorchcv_wrapper.resnet('imagenet', depth=model_depth, pretrained=True)
args.model.num_classes = args.num_classes
num_input_ftrs = args.model.output.in_features
# Replace the ImageNet classification head with a num_classes-way output layer.
args.model.output = nn.Linear(num_input_ftrs, args.model.num_classes)
args.validation_stream_winter_files = np.array(args.validation_stream_winter)[:, 0]
args.validation_stream_winter_files = [args.images_root + s for s in args.validation_stream_winter_files]
args.test_stream_winter_files = np.array(args.test_stream_winter)[:, 0]
args.test_stream_winter_files = [args.images_root + s for s in args.test_stream_winter_files]
def flattened(list_w_sublists):
    '''Flattens a list one level deep: items that are lists are expanded,
    all other items are kept as-is.'''
    flat = []
    for item in list_w_sublists:
        if isinstance(item, list):
            for val in item:
                flat.append(val)
        else:
            flat.append(item)
    return flat
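# Used below both to flatten the experience stream into single images (for
# --shuffle_stream) and to flatten the metrics list passed to the evaluator.
# Example: flattened([[1, 2], 3, [4]]) -> [1, 2, 3, 4]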
# Creating the winter file lists from the loaded data
winter_exp = args.exp_season_split_dict['winter']
summer_exp = args.exp_season_split_dict['summer']
winter_train_exp = [args.train_stream[i] for i in winter_exp]
all_winter_train_data = [i for sublist in winter_train_exp for i in sublist]
all_winter_train_files = np.array(all_winter_train_data)[:, 0]
all_winter_train_files = [args.images_root + s for s in all_winter_train_files]
all_test_files = np.array(args.test_stream)[:, 0]
all_test_files = [args.images_root + s for s in all_test_files]
all_val_files = np.array(args.val_stream)[:, 0]
all_val_files = [args.images_root + s for s in all_val_files]
args.all_winter_files = all_winter_train_files + args.test_stream_winter_files + args.validation_stream_winter_files
# Optimizer settings
eps = 1e-4
weight_decay = 75e-3
args.optimizer = optim.AdamW(args.model.parameters(), eps=eps, weight_decay=weight_decay)
args.criterion = LabelSmoothingCrossEntropy()
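# Note: no learning rate is passed, so AdamW runs with PyTorch's default lr of 1e-3.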
# Reload the train stream at the requested experience size; the 128 version loaded
# above was only needed to build the winter file lists.
with open(args.data_root + args.file + '_train_stream.pkl', 'rb') as handle:
    args.train_stream = pkl.load(handle)
if args.shuffle_stream:
    # Shuffle on image level, then re-chunk into experiences of the original size.
    train_exp_list_flatten = flattened(args.train_stream)
    random.shuffle(train_exp_list_flatten)
    shuffled_train_stream = []
    exp_size = len(args.train_stream[0])
    for i in range(len(args.train_stream)):
        shuffled_train_stream.append(train_exp_list_flatten[i * exp_size:(i + 1) * exp_size])
    args.train_stream = shuffled_train_stream
strategy = args.strategy
# Use the parameters to assemble a name for the logging folder.
data_ratio = ''
rehersal = ''
if args.rehearsal_method is not None:
    strategy = 'rehersal'
    data_ratio = '_dataratio' + str(args.buffer_data_ratio)
    rehersal = '_' + args.rehearsal_method
    if (args.rehearsal_method == 'ce' or args.rehearsal_method == 'cf') and args.temperature is not None:
        rehersal = rehersal + '_temp' + str(args.temperature)
    if args.rehearsal_method == 'wce':
        if args.nr_of_steps_to_avg is None or args.sigma is None:
            print('ATTENTION: weighted moving average variables not set')
            print(args.nr_of_steps_to_avg)
            print(args.sigma)
        else:
            rehersal = rehersal + '_navg' + str(args.nr_of_steps_to_avg) + '_sigma' + str(args.sigma)
if args.memory_filling != 'inf':
    total_train_data = len(args.train_stream[0]) * len(args.train_stream)
    print(args.memory_size)
    print(total_train_data)
    if args.memory_size < 1:
        # memory_size given as a ratio of the entire train stream
        memory_size = int(args.memory_size * total_train_data)
        print('r', memory_size)
    else:
        # memory_size given as an absolute number of instances
        memory_size = int(args.memory_size)
        print('i', memory_size)
    strategy = 'rehersal_' + args.memory_filling + '_memsize_' + str(args.memory_size)
    print('total memory size ', args.memory_size)
else:
    print('inf_mem')
    print(args.memory_filling)
if args.shuffle_stream:
    strategy = strategy + '_shuffled'
# Logging path
log_root = Path('/home/boehlke/AMMOD/continual_learning/results/')
if args.log_root is not None:
    log_root = Path(args.log_root)
model_settings = 'resnet' + str(model_depth) + '_smoothloss_adamW_eps' + str(eps) + 'wd' + str(weight_decay) + '_bs' + str(args.batch_size)
strategy_settings = strategy + rehersal + data_ratio + '_ep' + str(args.epochs) + '_ne' + str(args.eval_after_n_exp)
args.log_dir_name = log_root / model_settings / args.file / strategy_settings
print(args.log_dir_name)
Path.mkdir(args.log_dir_name, parents=True, exist_ok=True)
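# With the defaults this resolves to a directory like
#   <log_root>/resnet18_smoothloss_adamW_eps0.0001wd0.075_bs48/<file>/naive_ep3_ne1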
cls_in_teststream = np.unique(np.array(args.test_stream)[:, 1])
if cls_in_teststream.shape[0] != args.num_classes:
    print('ATTENTION: args.num_classes does not match the number of classes in test_stream')
    print(args.num_classes)
    #exit(0)
cls_in_valstream = np.unique(np.array(args.val_stream)[:, 1])
if cls_in_valstream.shape[0] != args.num_classes:
    print('ATTENTION: args.num_classes does not match the number of classes in val_stream')
    print(args.num_classes)
    #exit(0)
# Load the label dictionary; flip it if it is stored the wrong way around.
with open(args.label_dict, 'rb') as p:
    label_dict = pkl.load(p)
print(label_dict.items())
if 0 not in label_dict.keys():  # class names are the keys and the integer labels are the values
    print('label_dict wrong way around')
    flip_lable_dict = {}
    for key, value in label_dict.items():
        flip_lable_dict[value] = key
    label_dict = flip_lable_dict
# Augmentations for train and val/test data
resize = int(224 / 0.875)
args.data_transforms = {
    'train': transforms.Compose([
        transforms.Resize((resize, resize)),
        transforms.RandomCrop(size=(224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        #transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}
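# 224 / 0.875 = 256 follows the common "resize to 256, crop to 224" ImageNet
# convention, and the Normalize statistics are the standard ImageNet means/stds
# matching the ImageNet-pretrained backbone.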
def get_scenario(args):
    '''Returns a benchmark instance given a sequence of lists of files. A separate
    dataset is created for each list, and each of those datasets is considered a
    separate experience.

    The function paths_benchmark() is defined in /avalanche/scenarios/generic_benchmark_creation.py
    as "create_generic_benchmark_from_paths()", which has been augmented from the
    original avalanche version with a path_dataset_class argument that allows handling
    of data with more than path and label information. In this case, the sequence code
    is passed as well to allow for sequence-wise evaluation.
    '''

    if args.strategy == 'joint':
        # For the joint baseline, all experiences are merged into a single one.
        joint_train_stream = [item for sublist in args.train_stream for item in sublist]
        args.train_stream = [joint_train_stream]

    task_labels = np.zeros((1, len(args.train_stream))).astype(int).tolist()[0]
    nac = torch.tensor(args.non_animal_cls, device=args.detected_device)
    nc = args.num_classes
    metrics = [
        #TrainingTime(reset_at='epoch', emit_at='epoch', mode='train'),
        #accuracy_metrics(minibatch=False, epoch=True, experience=False, stream=True),
        Top1AccTransfer(nr_train_exp=len(args.train_stream), reset_at='stream', emit_at='stream', mode='eval'),
        ClassTop_nAvgAcc(nr_classes=nc, n=1, reset_at='stream', emit_at='stream', mode='eval'),
        #ClassTop_nAvgAcc(nr_classes=nc, n=1, reset_at='epoch', emit_at='epoch', mode='train'),
        ClassTop_nAvgAcc(nr_classes=nc, n=2, reset_at='stream', emit_at='stream', mode='eval'),
        SeasonClassTop_nAcc(nr_classes=nc, n=1, winter_files_list=args.all_winter_files, season='summer', reset_at='stream', emit_at='stream', mode='eval'),
        SeasonClassTop_nAcc(nr_classes=nc, n=1, winter_files_list=args.all_winter_files, season='winter', reset_at='stream', emit_at='stream', mode='eval'),
        #ClassTop_nAvgAcc(nr_classes=nc, n=2, reset_at='epoch', emit_at='epoch', mode='train'),
        ClasswiseTop_nAcc(nr_classes=nc, n=1, reset_at='stream', emit_at='stream', mode='eval'),
        Top_nAcc(n=1, reset_at='stream', emit_at='stream', mode='eval'),
        SeasonTop_nAcc(n=1, winter_files_list=args.all_winter_files, season='summer', reset_at='stream', emit_at='stream', mode='eval'),
        SeasonTop_nAcc(n=1, winter_files_list=args.all_winter_files, season='winter', reset_at='stream', emit_at='stream', mode='eval'),
        #Top_nAcc(n=2, reset_at='stream', emit_at='stream', mode='eval'),
        #Top_nAcc(n=1, reset_at='epoch', emit_at='epoch', mode='train'),
        #Top_nAcc(n=2, reset_at='epoch', emit_at='epoch', mode='train'),
        SeqTop_nAcc(n=1, reset_at='stream', emit_at='stream', mode='eval'),
        SeqMaxAcc(reset_at='stream', emit_at='stream', mode='eval'),
        SeasonSeqTop_nAcc(num=1, winter_files_list=args.all_winter_files, season='summer', reset_at='stream', emit_at='stream', mode='eval'),
        SeasonSeqTop_nAcc(num=1, winter_files_list=args.all_winter_files, season='winter', reset_at='stream', emit_at='stream', mode='eval'),
        #SeqTop_nAcc(n=2, reset_at='stream', emit_at='stream', mode='eval'),
        #SeqTop_nAcc(n=3, reset_at='stream', emit_at='stream', mode='eval'),
        #SeqTop_nAcc(n=1, reset_at='epoch', emit_at='epoch', mode='train'),
        #SeqTop_nAcc(n=2, reset_at='epoch', emit_at='epoch', mode='train'),
        #SeqTop_nAcc(n=3, reset_at='epoch', emit_at='epoch', mode='train'),
        SeqClasswiseTop_nAcc(nr_classes=nc, n=1, reset_at='stream', emit_at='stream', mode='eval'),
        SeqClassTop_nAvgAcc(nr_classes=nc, n=1, reset_at='stream', emit_at='stream', mode='eval'),
        SeqClassAvgMaxAcc(nr_classes=nc, reset_at='stream', emit_at='stream', mode='eval'),
        SeasonSeqClassTop_nAcc(nr_classes=nc, n=1, winter_files_list=args.all_winter_files, season='summer', reset_at='stream', emit_at='stream', mode='eval'),
        SeasonSeqClassTop_nAcc(nr_classes=nc, n=1, winter_files_list=args.all_winter_files, season='winter', reset_at='stream', emit_at='stream', mode='eval'),
        #SeqClassTop_nAvgAcc(nr_classes=nc, n=1, reset_at='epoch', emit_at='epoch', mode='train'),
        #SeqClassTop_nAvgAcc(nr_classes=nc, n=2, reset_at='stream', emit_at='stream', mode='eval'),
        #SeqClassTop_nAvgAcc(nr_classes=nc, n=2, reset_at='epoch', emit_at='epoch', mode='train'),
        SeqAnyAcc(reset_at='stream', emit_at='stream', mode='eval'),
        #SeqAnyAcc(reset_at='epoch', emit_at='epoch', mode='train'),
        #BinaryAnimalAcc(non_animal_class=nac, reset_at='epoch', emit_at='epoch', mode='train'),
        #BinaryAnimalAcc(non_animal_class=nac, reset_at='stream', emit_at='stream', mode='eval'),
        #QuatrupleAnimalAcc(non_animal_class=nac, reset_at='epoch', emit_at='epoch', mode='train'),
        #QuatrupleAnimalAcc(non_animal_class=nac, reset_at='stream', emit_at='stream', mode='eval'),
        #SeqBinaryAnimalAcc(non_animal_class=nac, reset_at='epoch', emit_at='epoch', mode='train'),
        #SeqBinaryAnimalAcc(non_animal_class=nac, reset_at='stream', emit_at='stream', mode='eval'),
        #SeqIgnoreNonAnimalOutputsAcc(non_animal_class=nac, nr_classes=nc, reset_at='epoch', emit_at='epoch', mode='train'),
        #SeqIgnoreNonAnimalOutputsAcc(non_animal_class=nac, nr_classes=nc, reset_at='stream', emit_at='stream', mode='eval'),
        #ClassIgnoreNonAnimalOutputsTop_nAvgAcc(nr_classes=nc, n=1, non_animal_class=nac, reset_at='epoch', emit_at='epoch', mode='train'),
        #ClassIgnoreNonAnimalOutputsTop_nAvgAcc(nr_classes=nc, n=1, non_animal_class=nac, reset_at='stream', emit_at='stream', mode='eval'),
        #IgnoreNonAnimalOutputsTop_nAcc(non_animal_class=nac, nr_classes=nc, n=1, reset_at='epoch', emit_at='epoch', mode='train'),
        #IgnoreNonAnimalOutputsTop_nAcc(non_animal_class=nac, nr_classes=nc, n=1, reset_at='stream', emit_at='stream', mode='eval'),
        #NrStepsBatchwise(reset_at='epoch', emit_at='epoch', mode='train'),
        #NrStepsImagewise(reset_at='epoch', emit_at='epoch', mode='train'),
    ]
    metrics = flattened(metrics)
    datetime_stamp = str(datetime.datetime.now()).replace(' ', '_').replace(':', '-').replace('.', '-')[:-4]
    tb_logger = TensorboardLogger(tb_log_dir=str(args.log_dir_name) + "/tb_data_" + datetime_stamp, filename_suffix='test_run')
    interactive_logger = InteractiveLogger()
    gen_csv_logger = GenericCSVLogger(log_folder=str(args.log_dir_name) + '/gen_csvlogs_' + datetime_stamp)
    args.combined_logger = EvaluationPlugin(
        metrics,
        loggers=[tb_logger, gen_csv_logger, interactive_logger],
        suppress_warnings=False)
    scenario = paths_benchmark(
        args.train_stream,
        [args.test_stream],
        other_streams_lists_of_files={'validation': [args.val_stream]},
        task_labels=task_labels,
        complete_test_set_only=True,
        train_transform=args.data_transforms['train'],
        eval_transform=args.data_transforms['val'],
        other_streams_transforms={'validation': args.data_transforms['val']},
        path_dataset_class=SeqPathsDataset,
        common_root=args.images_root
    )
    return scenario
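# The returned benchmark exposes scenario.train_stream, scenario.test_stream and
# scenario.validation_stream, which the train/eval loops at the bottom of this
# script iterate over.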
def get_strategy(args):
    '''Returns the avalanche training strategy with the necessary plugins.'''
    plugin = SeqDataPlugin()
    plugins = [plugin]
    if args.memory_filling == 'inf':
        if args.rehearsal_method == 'ce':
            if args.temperature is None:
                plugin_rehearsal = ClassErrorRehersalPlugin(buffer_data_ratio=args.buffer_data_ratio)
                plugins.append(plugin_rehearsal)
            else:
                plugin_rehearsal = ClassErrorRehersalTemperaturePlugin(buffer_data_ratio=args.buffer_data_ratio, temperature=args.temperature)
                plugins.append(plugin_rehearsal)
        elif args.rehearsal_method == 'cf':
            plugin_rehearsal = ClassFrequencyRehearsalPlugin(buffer_data_ratio=args.buffer_data_ratio)
            plugins.append(plugin_rehearsal)
        elif args.rehearsal_method == 'mir':
            plugin_rehearsal = MaximallyInterferedRetrievalRehersalPlugin(buffer_data_ratio=args.buffer_data_ratio)
            plugins.append(plugin_rehearsal)
        elif args.rehearsal_method == 'rr':
            plugin_rehearsal = RandomRehersal(buffer_data_ratio=args.buffer_data_ratio)
            plugins.append(plugin_rehearsal)
        elif args.rehearsal_method == 'wce':
            plugin_rehearsal = WeightedeMovingClassErrorAverageRehersalPlugin(buffer_data_ratio=args.buffer_data_ratio, nr_of_steps_to_avg=args.nr_of_steps_to_avg, nr_classes=args.num_classes, sigma=args.sigma)
            plugins.append(plugin_rehearsal)
        elif args.rehearsal_method == 'cefa':
            plugin_rehearsal = ClassErrorFrequencyAvgRehearsalPlugin(buffer_data_ratio=args.buffer_data_ratio, nr_classes=args.num_classes)
            plugins.append(plugin_rehearsal)
        elif args.rehearsal_method == 'fe':
            plugin_rehearsal = FillExpBasedRehearsalPlugin(buffer_data_ratio=args.buffer_data_ratio, nr_classes=args.num_classes)
            plugins.append(plugin_rehearsal)
        elif args.rehearsal_method == 'feo':
            plugin_rehearsal = FillExpOversampleBasedRehearsalPlugin(buffer_data_ratio=args.buffer_data_ratio, nr_classes=args.num_classes)
            plugins.append(plugin_rehearsal)
        else:
            print('no rehearsal method selected')
    elif args.memory_filling == 'cbrs':
        plugin_rehearsal = ClassBalancingReservoirMemoryRehersalPlugin(buffer_data_ratio=args.buffer_data_ratio, memory_size=memory_size, nr_classes=args.num_classes, rehearsal_selection_strategy=args.rehearsal_method, temperature=args.temperature)
        plugins.append(plugin_rehearsal)
    elif args.memory_filling == 'stdrs':
        plugin_rehearsal = ReservoirMemoryRehearsalPlugin(buffer_data_ratio=args.buffer_data_ratio, memory_size=memory_size, nr_classes=args.num_classes, rehearsal_selection_strategy=args.rehearsal_method, temperature=args.temperature)
        plugins.append(plugin_rehearsal)
    if args.strategy == 'joint':
        # In the joint baseline scenario the Naive base strategy defined in
        # /avalanche/training/strategies/base_strategies.py is used; it trains for
        # one epoch per call, with the epoch loop running in the main script below.
        print('getting joint strategy')
        strategy = Naive(
            args.model,
            args.optimizer,
            args.criterion,
            train_mb_size=args.batch_size,
            train_epochs=1,
            eval_mb_size=args.exp_size,
            device=args.detected_device,
            plugins=[plugin],
            evaluator=args.combined_logger,
            label_dict=label_dict
        )
    elif args.strategy == 'cumulative':
        print('getting cumulative strategy')
        strategy = Cumulative(
            args.model,
            args.optimizer,
            args.criterion,
            train_mb_size=args.batch_size,
            train_epochs=args.epochs,
            eval_mb_size=args.exp_size,
            device=args.detected_device,
            plugins=[plugin],
            evaluator=args.combined_logger,
            label_dict=label_dict
        )
    else:
        print('getting naive strategy')
        strategy = Naive(
            args.model,
            args.optimizer,
            args.criterion,
            train_mb_size=args.batch_size,
            train_epochs=args.epochs,
            eval_mb_size=args.exp_size,
            device=args.detected_device,
            plugins=plugins,
            evaluator=args.combined_logger,
            label_dict=label_dict
        )
    return strategy
scenario = get_scenario(args)
cl_strategy = get_strategy(args)
# Different loops are necessary for different strategies.
# When using a class-error-based method, the evaluation on the validation data should
# occur more frequently, i.e. eval_after_n_exp (-ne) should be less than 10.
if args.strategy == 'joint':
    for i in range(args.epochs):
        cl_strategy.train(scenario.train_stream[0], num_workers=number_workers)
        print('Training completed')
        last_eval = i == args.epochs - 1
        if i % args.test_eval_after_n_exp == 0 or last_eval:
            print('Computing accuracy on the whole test set')
            cl_strategy.eval(scenario.test_stream, num_workers=number_workers, last_eval=True)
        if i % args.eval_after_n_exp == 0 or last_eval:
            print('Computing accuracy on the whole validation set')
            cl_strategy.eval(scenario.validation_stream, num_workers=number_workers, last_eval=True)
        if last_eval:
            cl_strategy.eval(scenario.train_stream, num_workers=number_workers, last_eval=False)
else:
    train_stream_len = len(scenario.train_stream)
    for i, experience in enumerate(scenario.train_stream):
        cl_strategy.train(experience, num_workers=number_workers)
        print('Training completed')
        last_eval = i == train_stream_len - 1
        if i % args.eval_after_n_exp == 0 or last_eval:
            print('Evaluation after experience: ', i)
            #cl_strategy.eval(scenario.train_stream, num_workers=number_workers, last_eval=True)
            print('Computing accuracy on the whole validation set')
            cl_strategy.eval(scenario.validation_stream, num_workers=number_workers, last_eval=True)

        if i % args.test_eval_after_n_exp == 0 or last_eval:
            print('Evaluation after experience: ', i)
            print('Computing accuracy on the whole test set')
            cl_strategy.eval(scenario.test_stream, num_workers=number_workers, last_eval=True)
        if last_eval:
            cl_strategy.eval(scenario.train_stream, num_workers=number_workers, last_eval=True)