import pandas as pd
from skimage import io
import numpy as np
from PIL import Image
import torch
import torch.nn as nn  # All neural network modules, nn.Linear, nn.Conv2d, BatchNorm, Loss functions
import torch.optim as optim  # For all Optimization algorithms, SGD, Adam, etc.
import torchvision.transforms as transforms  # Transformations we can perform on our dataset
from torchvision.io import read_image
import os
from torch.utils.data import (
    Dataset,
    DataLoader,
)
from typing import Optional, Sequence, Union, List
from avalanche.benchmarks.generators import nc_benchmark, ni_benchmark
from avalanche.models import SimpleCNN
from avalanche.models import pytorchcv_wrapper
from avalanche.benchmarks.utils.datasets_from_filelists import SeqPathsDataset, PathsDataset
from avalanche.training.strategies import Naive, Cumulative, JointTraining
# os.environ['CUDA_LAUNCH_BLOCKING'] = "0"
from avalanche.benchmarks.generators import paths_benchmark
# from avalanche.training.plugins.sequence_data import _get_random_indicies, _get_cls_acc_based_indicies
from avalanche.training.plugins import (
    SeqDataPlugin,
    ReplayPlugin,
    EvaluationPlugin,
    MaximallyInterferedRetrievalRehersalPlugin,
    ClassErrorRehersalPlugin,
    ClassErrorRehersalTemperaturePlugin,
    ClassFrequencyRehearsalPlugin,
    RandomRehersal,
    ClassBalancingReservoirMemoryRehersalPlugin,
    ReservoirMemoryRehearsalPlugin,
    WeightedeMovingClassErrorAverageRehersalPlugin,
    ClassErrorFrequencyAvgRehearsalPlugin,
    FillExpBasedRehearsalPlugin,
    FillExpOversampleBasedRehearsalPlugin,
)
from avalanche.evaluation.metrics import (
    ClassTop_nAvgAcc,
    SeasonClassTop_nAcc,
    Top1AccTransfer,
    ClasswiseTop_nAcc,
    Top_nAcc,
    SeasonTop_nAcc,
    SeqMaxAcc,
    SeqClassTop_nAvgAcc,
    SeqClassAvgMaxAcc,
    SeasonSeqClassTop_nAcc,
    SeqClasswiseTop_nAcc,
    SeqAnyAcc,
    SeasonSeqTop_nAcc,
    SeqTop_nAcc,
    BinaryAnimalAcc,
    QuatrupleAnimalAcc,
    SeqBinaryAnimalAcc,
    IgnoreNonAnimalOutputsTop_nAcc,
    SeqIgnoreNonAnimalOutputsAcc,
    ClassIgnoreNonAnimalOutputsTop_nAvgAcc,
)
from avalanche.logging import (
    StrategyLogger,
    InteractiveLogger,
    TextLogger,
    TensorboardLogger,
    CSVLogger,
    GenericCSVLogger,
)
from avalanche.evaluation.metrics import accuracy_metrics, loss_metrics
import matplotlib.pyplot as plt
import pickle as pkl
import time
import argparse
import threading
from avalanche.models.utils import LabelSmoothingCrossEntropy
from pathlib import Path
import datetime
import sys
import random

# ---------------------------------------------------------------------------
# Command line interface
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()
# --file is concatenated with --data_root to locate every *_stream.pkl file,
# so the script cannot run without it.
parser.add_argument('--file', '-f', type=str, required=True,
                    help='file name prefix of the train/test/val_stream.pkl files (appended to --data_root)')
parser.add_argument('--data_root', '-roots', type=str,
                    default='/home/boehlke/AMMOD/continual_learning/data/boehlke_ssd_Species/',
                    help='directory from which the *_stream.pkl files are loaded '
                         '(joined to --file by plain string concatenation, keep the trailing slash)')
parser.add_argument('--log_root', '-log', type=str,
                    default='/home/boehlke/AMMOD/continual_learning/results_ever_eval_test/',
                    help='root directory under which results are logged')
parser.add_argument('--strategy', '-s', type=str, default='naive',
                    help='strategy name')
parser.add_argument('--epochs', '-e', type=int, default=3,
                    help='number of times the finetuning set (rehearsal+experience) are itereted over')
parser.add_argument('--batch_size', '-bs', type=int, default=48)
parser.add_argument('--eval_after_n_exp', '-ne', type=int, default=1,
                    help='interval of evaluation on validation set in experiences, '
                         'in case of joint training, interval in epochs')
parser.add_argument('--test_eval_after_n_exp', '-tne', type=int, default=100,
                    help='interval of evaluation on test set in experiences, '
                         'in case of joint training, interval in epochs')
parser.add_argument('--exp_size', '-es', type=int, default=128,
                    help='exp size, also used as eval batch exp_size')
parser.add_argument('--memory_filling', '-memf', type=str, default='inf',
                    help='how the memory shall be filled, default is inf, other options(cbrs, stdrs)')
parser.add_argument('--memory_size', '-mems', type=float, default=0.3,
                    help='either ratio (< 1) of entire train stream saved in memory, '
                         'or if >= 1, number of instances possible in memory')
parser.add_argument('--temperature', '-temp', type=float, default=None,
                    help='class error based reheasal with sharpening/softening temperature value')
parser.add_argument('--rehearsal_method', '-rm', type=str, default=None,
                    help='abreviation corresponding to one of the implemented rehersal methods such as '
                         '"ce, cf, rr, mir, wce, cefa, fe, feo", see get_strategy()')
parser.add_argument('--nr_of_steps_to_avg', '-nexp_avg', type=int, default=None,
                    help='needs to be set for WeightedeMovingClassErrorAverageRehersalPlugin and defines '
                         'how many past exp are used in weighted average')
parser.add_argument('--sigma', '-sig', type=float, default=None,
                    help='needs to be set for WeightedeMovingClassErrorAverageRehersalPlugin, sets the rate '
                         'of decay for weights in moving weighted average. The bigger sigma, the more past '
                         'error rates are weitghed')
parser.add_argument('--buffer_data_ratio', '-dr', type=float, default=8,
                    help='how many images from memory should be used in the rehersal set for each image '
                         'in the experience')
parser.add_argument('--shuffle_stream', '-ss', action='store_true', default=False,
                    help='if train_stream should be shuffled on an image level to obtain mode iid '
                         'distribution in stream for comparison')
parser.add_argument('--num_classes', '-nc', type=int, default=16,
                    help='number of classes in dataset')
parser.add_argument('--non_animal_cls', '-nac', default=[13, 14, 15], nargs='+', type=int,
                    help='classes, that do not contain species, which is relevant for "binary" accuracies')
parser.add_argument('--images_root', '-ir', type=str,
                    default='/home/AMMOD_data/camera_traps/BayerWald/G-Fallen/original/',
                    help='directory where image files are stored')
parser.add_argument('--label_dict', '-ld', type=str,
                    default="/home/boehlke/AMMOD/cam_trap_classification/data/csv_files/BIRDS.pkl",
                    help='pickled dict mapping class indices to class names (either orientation is '
                         'accepted); used to create confusion matrix at c-n intervals')
args = parser.parse_args()

print("PyTorch Version: ", torch.__version__)
use_cuda = torch.cuda.is_available()
args.detected_device = torch.device("cuda:0" if use_cuda else "cpu")
print(args.detected_device)
number_workers = 8


def _load_pickle(path):
    """Return the object stored in the pickle file at ``path``.

    NOTE: unpickling can execute arbitrary code; only load trusted, locally
    produced stream files with this script.
    """
    with open(path, 'rb') as handle:
        return pkl.load(handle)


# Train and test data.
# Even for training with a stream of exp size 384, the exp-size-128 files are
# loaded first so that they match the winter stream files.
stream_prefix = args.data_root + args.file.replace('384', '128')
args.train_stream = _load_pickle(stream_prefix + '_train_stream.pkl')
args.test_stream = _load_pickle(stream_prefix + '_test_stream.pkl')  # [:5000]
args.val_stream = _load_pickle(stream_prefix + '_val_stream.pkl')  # [:5000]
args.validation_stream_winter = _load_pickle(stream_prefix + '_winter_val_stream.pkl')
args.test_stream_winter = _load_pickle(stream_prefix + '_winter_test_stream.pkl')
# The season split dict is shared by the '_crop' and non-crop variants.
args.exp_season_split_dict = _load_pickle(
    args.data_root + args.file.replace('_crop', '').replace('384', '128') + '_exp_season_split_dict.pkl')

# Backbone model settings
model_depth = 18
args.model = pytorchcv_wrapper.resnet('imagenet', depth=model_depth, pretrained=True)
args.model.num_classes = args.num_classes
num_input_ftrs = args.model.output.in_features
# Replace the pretrained classification head with one sized for this dataset.
args.model.output = nn.Linear(num_input_ftrs, args.model.num_classes)


def _image_paths(stream, images_root):
    """Return column 0 (the relative image path) of every sample in ``stream``, prefixed with ``images_root``."""
    # s[0:] is kept from the original code; presumably it turns the numpy string
    # into a plain ``str`` before concatenation.
    return [images_root + s[0:] for s in np.array(stream)[:, 0]]


args.validation_stream_winter_files = _image_paths(args.validation_stream_winter, args.images_root)
args.test_stream_winter_files = _image_paths(args.test_stream_winter, args.images_root)


def flattened(list_w_sublists):
    """Flatten exactly one level of nesting.

    Items that are ``list`` instances are expanded in place; any other item is
    kept as-is. Deeper nesting is not flattened.

    >>> flattened([[1, 2], 3, [4, [5]]])
    [1, 2, 3, 4, [5]]
    """
    flat = []
    for item in list_w_sublists:
        if isinstance(item, list):
            flat.extend(item)
        else:
            flat.append(item)
    return flat
# creating the winter list from the loaded data: image paths of every training
# sample in a winter experience plus the winter validation/test streams. The
# Season* metrics use this list to separate winter from summer images.
winter_exp = args.exp_season_split_dict['winter']
summer_exp = args.exp_season_split_dict['summer']  # not used below
all_winter_train_data = [sample for exp_idx in winter_exp for sample in args.train_stream[exp_idx]]
# Guard the empty case: np.array([])[:, 0] would raise on a 1-D array.
all_winter_train_files = (
    [args.images_root + s[0:] for s in np.array(all_winter_train_data)[:, 0]]
    if all_winter_train_data else []
)
args.all_winter_files = (all_winter_train_files
                         + args.test_stream_winter_files
                         + args.validation_stream_winter_files)

# Optimizer Settings
eps = 1e-4
weight_decay = 75e-3
args.optimizer = optim.AdamW(args.model.parameters(), eps=eps, weight_decay=weight_decay)
args.criterion = LabelSmoothingCrossEntropy()

# Reload the train stream under the real file name; the exp-size-128 version
# loaded earlier was only needed to match the winter stream files above.
with open(args.data_root + args.file + '_train_stream.pkl', 'rb') as handle:
    args.train_stream = pkl.load(handle)

if args.shuffle_stream:
    # Shuffle on image level (closer to an iid stream) and re-chunk into
    # experiences with exactly the original experience sizes, so no sample is
    # dropped even if the experiences are not all equally large.
    train_exp_list_flatten = flattened(args.train_stream)
    random.shuffle(train_exp_list_flatten)
    shuffled_train_stream = []
    start = 0
    for experience_list in args.train_stream:
        stop = start + len(experience_list)
        shuffled_train_stream.append(train_exp_list_flatten[start:stop])
        start = stop
    args.train_stream = shuffled_train_stream

# using parameters to assemble a name for the logging folder
strategy = args.strategy
data_ratio = ''
rehersal = ''
if args.rehearsal_method is not None:
    strategy = 'rehersal'
    data_ratio = '_dataratio' + str(args.buffer_data_ratio)
    rehersal = '_' + args.rehearsal_method
    if args.rehearsal_method in ('ce', 'cf') and args.temperature is not None:
        rehersal = rehersal + '_temp' + str(args.temperature)
    if args.rehearsal_method == 'wce':
        if args.nr_of_steps_to_avg is None or args.sigma is None:
            print('ATTENTION: weighted moving average variables not set')
            print(args.nr_of_steps_to_avg)
            print(args.sigma)
        else:
            rehersal = rehersal + '_navg' + str(args.nr_of_steps_to_avg) + '_sigma' + str(args.sigma)

if args.memory_filling != 'inf':
    # Count every sample; experiences are not guaranteed to share the size of
    # the first one (e.g. a shorter final experience).
    total_train_data = sum(len(experience_list) for experience_list in args.train_stream)
    print(args.memory_size)
    print(total_train_data)
    # memory_size (module level) is read by get_strategy() for bounded memories:
    # a ratio of all training samples if < 1, otherwise an absolute count.
    if args.memory_size < 1:
        memory_size = int(args.memory_size * total_train_data)
    else:
        memory_size = int(args.memory_size)
    strategy = 'rehersal_' + args.memory_filling + '_memsize_' + str(args.memory_size)
    print('total memory size ', memory_size)
else:
    print('inf_mem')
print(args.memory_filling)
if args.shuffle_stream:
    strategy = strategy + '_shuffled'

# Logging path
log_root = (Path(args.log_root) if args.log_root is not None
            else Path('/home/boehlke/AMMOD/continual_learning/results/'))
model_settings = ('resnet' + str(model_depth) + '_smoothloss_adamW_eps' + str(eps)
                  + 'wd' + str(weight_decay) + '_bs' + str(args.batch_size))
strategy_settings = (strategy + rehersal + data_ratio + '_ep' + str(args.epochs)
                     + '_ne' + str(args.eval_after_n_exp))
args.log_dir_name = log_root / model_settings / args.file / strategy_settings
print(args.log_dir_name)
args.log_dir_name.mkdir(parents=True, exist_ok=True)

# Warn (but keep going) when an evaluation stream does not contain every class;
# column 1 of each sample is its class label.
for stream_name, eval_stream in (('test_stream', args.test_stream), ('val_stream', args.val_stream)):
    cls_in_stream = np.unique(np.array(eval_stream)[:, 1])
    if cls_in_stream.shape[0] != args.num_classes:
        print('ATTENTION: args.num_classes doesnt mathc number of classes in ' + stream_name)
        print(args.num_classes)
        # exit(0)

# Load the class mapping and normalise it to {class index: class name}.
with open(args.label_dict, 'rb') as p:
    label_dict = pkl.load(p)
print(label_dict.items())
if 0 not in label_dict:
    # The class names are the keys and the integer labels the values: invert it.
    print('label_dict wrong way around')
    label_dict = {value: key for key, value in label_dict.items()}
# augmentations for train and val/test data
# Channel statistics of ImageNet, matching the ImageNet-pretrained backbone.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
resize = int(224 / 0.875)  # 256: a 224 crop then covers 87.5% of each side
args.data_transforms = {
    'train': transforms.Compose([
        transforms.Resize((resize, resize)),
        transforms.RandomCrop(size=(224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        # transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ]),
    'val': transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ]),
}


def get_scenario(args):
    '''Return a benchmark built from the file lists stored on ``args``.

    A separate dataset is created for each list in ``args.train_stream``; each
    of those datasets is a separate experience. For the 'joint' strategy all
    experiences are first merged into a single one (``args.train_stream`` is
    overwritten).

    paths_benchmark() is defined in /avalanche/scenarios/generic_benchmark_creation.py
    as create_generic_benchmark_from_paths(), augmented from the original
    avalanche version with a ``path_dataset_class`` argument that allows data
    carrying more than path and label information. Here the sequence code is
    passed as well, to allow sequence-wise evaluation.

    Side effect: the evaluation plugin (metrics + loggers) is stored on
    ``args.combined_logger``; get_strategy() passes it to the strategy.

    Returns the benchmark with train, test and 'validation' streams.
    '''
    if args.strategy == 'joint':
        joint_train_stream = [item for sublist in args.train_stream for item in sublist]
        args.train_stream = [joint_train_stream]
    # Task-incremental information is not used: every experience is task 0.
    task_labels = [0] * len(args.train_stream)

    nac = torch.tensor(args.non_animal_cls, device=args.detected_device)  # used by the disabled *Animal* metrics
    nc = args.num_classes
    stream_eval = dict(reset_at='stream', emit_at='stream', mode='eval')
    winter = args.all_winter_files
    metrics = [
        # TrainingTime(reset_at='epoch', emit_at='epoch', mode='train'),
        # accuracy_metrics(minibatch=False, epoch=True, experience=False, stream=True),
        Top1AccTransfer(nr_train_exp=len(args.train_stream), **stream_eval),
        ClassTop_nAvgAcc(nr_classes=nc, n=1, **stream_eval),
        # ClassTop_nAvgAcc(nr_classes=nc, n=1, reset_at='epoch', emit_at='epoch', mode='train'),
        ClassTop_nAvgAcc(nr_classes=nc, n=2, **stream_eval),
        SeasonClassTop_nAcc(nr_classes=nc, n=1, winter_files_list=winter, season='summer', **stream_eval),
        SeasonClassTop_nAcc(nr_classes=nc, n=1, winter_files_list=winter, season='winter', **stream_eval),
        # ClassTop_nAvgAcc(nr_classes=nc, n=2, reset_at='epoch', emit_at='epoch', mode='train'),
        ClasswiseTop_nAcc(nr_classes=nc, n=1, **stream_eval),
        Top_nAcc(n=1, **stream_eval),
        SeasonTop_nAcc(n=1, winter_files_list=winter, season='summer', **stream_eval),
        SeasonTop_nAcc(n=1, winter_files_list=winter, season='winter', **stream_eval),
        # Top_nAcc(n=2, reset_at='stream', emit_at='stream', mode='eval'),
        # Top_nAcc(n=1, reset_at='epoch', emit_at='epoch', mode='train'),
        # Top_nAcc(n=2, reset_at='epoch', emit_at='epoch', mode='train'),
        SeqTop_nAcc(n=1, **stream_eval),
        SeqMaxAcc(**stream_eval),
        SeasonSeqTop_nAcc(num=1, winter_files_list=winter, season='summer', **stream_eval),
        SeasonSeqTop_nAcc(num=1, winter_files_list=winter, season='winter', **stream_eval),
        # SeqTop_nAcc(n=2, reset_at='stream', emit_at='stream', mode='eval'),
        # SeqTop_nAcc(n=3, reset_at='stream', emit_at='stream', mode='eval'),
        # SeqTop_nAcc(n=1, reset_at='epoch', emit_at='epoch', mode='train'),
        # SeqTop_nAcc(n=2, reset_at='epoch', emit_at='epoch', mode='train'),
        # SeqTop_nAcc(n=3, reset_at='epoch', emit_at='epoch', mode='train'),
        SeqClasswiseTop_nAcc(nr_classes=nc, n=1, **stream_eval),
        SeqClassTop_nAvgAcc(nr_classes=nc, n=1, **stream_eval),
        SeqClassAvgMaxAcc(nr_classes=nc, **stream_eval),
        SeasonSeqClassTop_nAcc(nr_classes=nc, n=1, winter_files_list=winter, season='summer', **stream_eval),
        SeasonSeqClassTop_nAcc(nr_classes=nc, n=1, winter_files_list=winter, season='winter', **stream_eval),
        # SeqClassTop_nAvgAcc(nr_classes=nc, n=1, reset_at='epoch', emit_at='epoch', mode='train'),
        # SeqClassTop_nAvgAcc(nr_classes=nc, n=2, reset_at='stream', emit_at='stream', mode='eval'),
        # SeqClassTop_nAvgAcc(nr_classes=nc, n=2, reset_at='epoch', emit_at='epoch', mode='train'),
        SeqAnyAcc(**stream_eval),
        # SeqAnyAcc(reset_at='epoch', emit_at='epoch', mode='train'),
        # BinaryAnimalAcc(non_animal_class=nac, reset_at='epoch', emit_at='epoch', mode='train'),
        # BinaryAnimalAcc(non_animal_class=nac, reset_at='stream', emit_at='stream', mode='eval'),
        # QuatrupleAnimalAcc(non_animal_class=nac, reset_at='epoch', emit_at='epoch', mode='train'),
        # QuatrupleAnimalAcc(non_animal_class=nac, reset_at='stream', emit_at='stream', mode='eval'),
        # SeqBinaryAnimalAcc(non_animal_class=nac, reset_at='epoch', emit_at='epoch', mode='train'),
        # SeqBinaryAnimalAcc(non_animal_class=nac, reset_at='stream', emit_at='stream', mode='eval'),
        # SeqIgnoreNonAnimalOutputsAcc(non_animal_class=nac, nr_classes=nc, reset_at='epoch', emit_at='epoch', mode='train'),
        # SeqIgnoreNonAnimalOutputsAcc(non_animal_class=nac, nr_classes=nc, reset_at='stream', emit_at='stream', mode='eval'),
        # ClassIgnoreNonAnimalOutputsTop_nAvgAcc(nr_classes=nc, n=1, non_animal_class=nac, reset_at='epoch', emit_at='epoch', mode='train'),
        # ClassIgnoreNonAnimalOutputsTop_nAvgAcc(nr_classes=nc, n=1, non_animal_class=nac, reset_at='stream', emit_at='stream', mode='eval'),
        # IgnoreNonAnimalOutputsTop_nAcc(non_animal_class=nac, nr_classes=nc, n=1, reset_at='epoch', emit_at='epoch', mode='train'),
        # IgnoreNonAnimalOutputsTop_nAcc(non_animal_class=nac, nr_classes=nc, n=1, reset_at='stream', emit_at='stream', mode='eval'),
        # NrStepsBatchwise(reset_at='epoch', emit_at='epoch', mode='train'),
        # NrStepsImagewise(reset_at='epoch', emit_at='epoch', mode='train')
    ]
    # Metric helpers may return lists of metrics; flatten one level.
    metrics = flattened(metrics)

    # e.g. '2021-06-01_12-30-45-12': a filesystem-safe timestamp for the log dirs.
    datetime_stamp = str(datetime.datetime.now()).replace(' ', '_').replace(':', '-').replace('.', '-')[:-4]
    tb_logger = TensorboardLogger(tb_log_dir=str(args.log_dir_name) + "/tb_data_" + datetime_stamp,
                                  filename_suffix='test_run')
    interactive_logger = InteractiveLogger()
    gen_csv_logger = GenericCSVLogger(log_folder=str(args.log_dir_name) + '/gen_csvlogs_' + datetime_stamp)
    args.combined_logger = EvaluationPlugin(
        metrics,
        loggers=[tb_logger, gen_csv_logger, interactive_logger],
        suppress_warnings=False)

    scenario = paths_benchmark(
        args.train_stream,
        [args.test_stream],
        other_streams_lists_of_files={'validation': [args.val_stream]},
        task_labels=task_labels,
        complete_test_set_only=True,
        train_transform=args.data_transforms['train'],
        eval_transform=args.data_transforms['val'],
        # Keyed by the stream name used above ('validation'); an (x, y)
        # transform tuple as documented for avalanche's benchmark generators.
        other_streams_transforms={'validation': (args.data_transforms['val'], None)},
        path_dataset_class=SeqPathsDataset,
        common_root=args.images_root,
    )
    return scenario


def get_strategy(args):
    '''Return the avalanche training strategy with the necessary plugins.

    A SeqDataPlugin is always created. A rehearsal plugin is added depending on
    ``args.memory_filling``:
      * 'inf'   - unbounded memory; ``args.rehearsal_method`` picks the plugin
                  ('ce', 'cf', 'mir', 'rr', 'wce', 'cefa', 'fe', 'feo').
      * 'cbrs'  - class-balancing reservoir memory of size ``memory_size``.
      * 'stdrs' - standard reservoir memory of size ``memory_size``.
    Rehearsal plugins are only used by the naive strategy; 'joint' and
    'cumulative' receive the SeqDataPlugin alone.

    Reads the module-level ``label_dict`` and, for 'cbrs'/'stdrs', the
    module-level ``memory_size``. Requires ``args.combined_logger`` (set by
    get_scenario()).
    '''
    plugin = SeqDataPlugin()
    plugins = [plugin]
    ratio = args.buffer_data_ratio

    if args.memory_filling == 'inf':
        inf_memory_plugin_factories = {
            'ce': lambda: (ClassErrorRehersalPlugin(buffer_data_ratio=ratio)
                           if args.temperature is None
                           else ClassErrorRehersalTemperaturePlugin(buffer_data_ratio=ratio,
                                                                    temperature=args.temperature)),
            'cf': lambda: ClassFrequencyRehearsalPlugin(buffer_data_ratio=ratio),
            'mir': lambda: MaximallyInterferedRetrievalRehersalPlugin(buffer_data_ratio=ratio),
            'rr': lambda: RandomRehersal(buffer_data_ratio=ratio),
            'wce': lambda: WeightedeMovingClassErrorAverageRehersalPlugin(
                buffer_data_ratio=ratio,
                nr_of_steps_to_avg=args.nr_of_steps_to_avg,
                nr_classes=args.num_classes,
                sigma=args.sigma),
            'cefa': lambda: ClassErrorFrequencyAvgRehearsalPlugin(buffer_data_ratio=ratio,
                                                                  nr_classes=args.num_classes),
            'fe': lambda: FillExpBasedRehearsalPlugin(buffer_data_ratio=ratio,
                                                      nr_classes=args.num_classes),
            'feo': lambda: FillExpOversampleBasedRehearsalPlugin(buffer_data_ratio=ratio,
                                                                 nr_classes=args.num_classes),
        }
        make_rehearsal_plugin = inf_memory_plugin_factories.get(args.rehearsal_method)
        if make_rehearsal_plugin is not None:
            plugins.append(make_rehearsal_plugin())
    elif args.memory_filling in ('cbrs', 'stdrs'):
        memory_plugin_cls = (ClassBalancingReservoirMemoryRehersalPlugin
                             if args.memory_filling == 'cbrs'
                             else ReservoirMemoryRehearsalPlugin)
        plugins.append(memory_plugin_cls(buffer_data_ratio=ratio,
                                         memory_size=memory_size,
                                         nr_classes=args.num_classes,
                                         rehearsal_selection_strategy=args.rehearsal_method,
                                         temperature=args.temperature))

    # Only report a missing rehearsal plugin when none was actually added.
    if len(plugins) == 1:
        print('no rehearsal method selected')

    common_kwargs = dict(
        train_mb_size=args.batch_size,
        eval_mb_size=args.exp_size,
        device=args.detected_device,
        evaluator=args.combined_logger,
        label_dict=label_dict,
    )
    if args.strategy == 'joint':
        # In the joint baseline scenario the naive base strategy defined in
        # /avalanche/training/strategies/base_strategies.py is used; the main
        # loop calls train() once per epoch, hence train_epochs=1.
        print('getting joint strategy')
        strategy = Naive(args.model, args.optimizer, args.criterion,
                         train_epochs=1, plugins=[plugin], **common_kwargs)
    elif args.strategy == 'cumulative':
        print('getting cumulative strategy')
        strategy = Cumulative(args.model, args.optimizer, args.criterion,
                              train_epochs=args.epochs, plugins=[plugin], **common_kwargs)
    else:
        print('getting naive strategy')
        strategy = Naive(args.model, args.optimizer, args.criterion,
                         train_epochs=args.epochs, plugins=plugins, **common_kwargs)
    return strategy


# get_scenario() must run first: it sets args.combined_logger (and, for the
# joint strategy, merges args.train_stream) before the strategy is built.
scenario = get_scenario(args)
cl_strategy = get_strategy(args)
# different loops are necessary for different strategies.
# When using a class-error based rehearsal method, the validation set should be
# evaluated more often, i.e. eval_after_n_exp (-ne) should be less than 10.
# The joint baseline counts its evaluation intervals in epochs; every other
# strategy counts them in experiences.


def _eval_due(step, interval, final_step):
    """Return True if an evaluation is scheduled after ``step``.

    Evaluations happen every ``interval`` steps (step 0 included) and always
    after the final step.
    """
    return step % interval == 0 or final_step


if args.strategy == 'joint':
    for epoch in range(args.epochs):
        # The joint stream holds one merged experience; train on it once per epoch.
        cl_strategy.train(scenario.train_stream[0], num_workers=number_workers)
        print('Training completed')
        final_step = epoch == args.epochs - 1
        if _eval_due(epoch, args.test_eval_after_n_exp, final_step):
            print('Computing accuracy on the whole test set')
            cl_strategy.eval(scenario.test_stream, num_workers=number_workers, last_eval=True)
        if _eval_due(epoch, args.eval_after_n_exp, final_step):
            print('Computing accuracy on the whole validation set')
            cl_strategy.eval(scenario.validation_stream, num_workers=number_workers, last_eval=True)
        if final_step:
            cl_strategy.eval(scenario.train_stream, num_workers=number_workers, last_eval=False)
else:
    last_index = len(scenario.train_stream) - 1
    for index, experience in enumerate(scenario.train_stream):
        cl_strategy.train(experience, num_workers=number_workers)
        print('Training completed')
        final_step = index == last_index
        if _eval_due(index, args.eval_after_n_exp, final_step):
            print('Evaluaton after experience: ', index)
            # cl_strategy.eval(scenario.train_stream, num_workers=number_workers, last_eval=True)
            print('Computing accuracy on the whole validation set')
            cl_strategy.eval(scenario.validation_stream, num_workers=number_workers, last_eval=True)
        if _eval_due(index, args.test_eval_after_n_exp, final_step):
            print('Evaluaton after experience: ', index)
            print('Computing accuracy on the whole test set')
            cl_strategy.eval(scenario.test_stream, num_workers=number_workers, last_eval=True)
        if final_step:
            cl_strategy.eval(scenario.train_stream, num_workers=number_workers, last_eval=True)