|
@@ -0,0 +1,520 @@
|
|
|
+import pandas as pd
|
|
|
+from skimage import io
|
|
|
+import numpy as np
|
|
|
+from PIL import Image
|
|
|
+import torch
|
|
|
+import torch.nn as nn # All neural network modules, nn.Linear, nn.Conv2d, BatchNorm, Loss functions
|
|
|
+import torch.optim as optim # For all Optimization algorithms, SGD, Adam, etc.
|
|
|
+import torchvision.transforms as transforms # Transformations we can perform on our dataset
|
|
|
+from torchvision.io import read_image
|
|
|
+import os
|
|
|
+from torch.utils.data import (
|
|
|
+ Dataset,
|
|
|
+ DataLoader,
|
|
|
+)
|
|
|
+from typing import Optional, Sequence, Union, List
|
|
|
+from avalanche.benchmarks.generators import nc_benchmark, ni_benchmark
|
|
|
+from avalanche.models import SimpleCNN
|
|
|
+from avalanche.models import pytorchcv_wrapper
|
|
|
+from avalanche.benchmarks.utils.datasets_from_filelists import SeqPathsDataset, PathsDataset
|
|
|
+from avalanche.training.strategies import Naive, Cumulative, JointTraining
|
|
|
+#os.environ['CUDA_LAUNCH_BLOCKING'] = "0"
|
|
|
+from avalanche.benchmarks.generators import paths_benchmark
|
|
|
+#from avalanche.training.plugins.sequence_data import _get_random_indicies, _get_cls_acc_based_indicies
|
|
|
+from avalanche.training.plugins import (SeqDataPlugin,
|
|
|
+ ReplayPlugin,
|
|
|
+ EvaluationPlugin,
|
|
|
+ MaximallyInterferedRetrievalRehersalPlugin,
|
|
|
+ ClassErrorRehersalPlugin,
|
|
|
+ ClassErrorRehersalTemperaturePlugin,
|
|
|
+ ClassFrequencyRehearsalPlugin,
|
|
|
+ RandomRehersal,
|
|
|
+ ClassBalancingReservoirMemoryRehersalPlugin,
|
|
|
+ ReservoirMemoryRehearsalPlugin,
|
|
|
+ WeightedeMovingClassErrorAverageRehersalPlugin,
|
|
|
+ ClassErrorFrequencyAvgRehearsalPlugin,
|
|
|
+ FillExpBasedRehearsalPlugin,
|
|
|
+ FillExpOversampleBasedRehearsalPlugin)
|
|
|
+from avalanche.evaluation.metrics import (ClassTop_nAvgAcc,
|
|
|
+ SeasonClassTop_nAcc,
|
|
|
+ Top1AccTransfer,
|
|
|
+ ClasswiseTop_nAcc,
|
|
|
+ Top_nAcc,
|
|
|
+ SeasonTop_nAcc,
|
|
|
+ SeqMaxAcc,
|
|
|
+ SeqClassTop_nAvgAcc,
|
|
|
+ SeqClassAvgMaxAcc,
|
|
|
+ SeasonSeqClassTop_nAcc,
|
|
|
+ SeqClasswiseTop_nAcc,
|
|
|
+ SeqAnyAcc,
|
|
|
+ SeasonSeqTop_nAcc,
|
|
|
+ SeqTop_nAcc,
|
|
|
+ BinaryAnimalAcc,
|
|
|
+ QuatrupleAnimalAcc,
|
|
|
+ SeqBinaryAnimalAcc,
|
|
|
+ IgnoreNonAnimalOutputsTop_nAcc,
|
|
|
+ SeqIgnoreNonAnimalOutputsAcc,
|
|
|
+ ClassIgnoreNonAnimalOutputsTop_nAvgAcc
|
|
|
+ )
|
|
|
+from avalanche.logging import StrategyLogger, InteractiveLogger, TextLogger, TensorboardLogger, CSVLogger, GenericCSVLogger
|
|
|
+from avalanche.evaluation.metrics import accuracy_metrics, loss_metrics
|
|
|
+import matplotlib.pyplot as plt
|
|
|
+import pickle as pkl
|
|
|
+import time
|
|
|
+import argparse
|
|
|
+import threading
|
|
|
+from avalanche.models.utils import LabelSmoothingCrossEntropy
|
|
|
+from pathlib import Path
|
|
|
+import datetime
|
|
|
+import sys
|
|
|
+import random
|
|
|
+
|
|
|
+parser =argparse.ArgumentParser()
|
|
|
+parser.add_argument('--file', '-f', type=str, help='file name where train/test/val_stream.pkl should be loaded')
|
|
|
+parser.add_argument('--data_root', '-roots', type=str, default='/home/boehlke/AMMOD/continual_learning/data/boehlke_ssd_Species/', help='file name where train/test_stream.pkl should be loaded')
|
|
|
+parser.add_argument('--log_root', '-log', type=str, default='/home/boehlke/AMMOD/continual_learning/results_ever_eval_test/', help='file results shall be logged')
|
|
|
+parser.add_argument('--strategy', '-s', type=str, default='naive', help='strategy name')
|
|
|
+parser.add_argument('--epochs', '-e', type=int, default=3, help='number of times the finetuning set (rehearsal+experience) are itereted over')
|
|
|
+parser.add_argument('--batch_size', '-bs', type=int, default=48)
|
|
|
+parser.add_argument('--eval_after_n_exp', '-ne', type=int, default=1, help='interval of evaluation on validation set in experiences, in case of joint training, interval in epochs')
|
|
|
+parser.add_argument('--test_eval_after_n_exp', '-tne', type=int, default=100, help='interval of evaluation on validation set in experiences, in case of joint training, interval in epochs')
|
|
|
+parser.add_argument('--exp_size', '-es', type=int, default=128, help='exp size, also used as eval batch exp_size' )
|
|
|
+parser.add_argument('--memory_filling', '-memf', type=str, default='inf', help='how the memory shall be filled, default is inf, other options(cbrs_mem, stdrs_mem)')
|
|
|
+parser.add_argument('--memory_size', '-mems', type=float, default=0.3, help='either ratio of entire train stream saved in memory, or if larger than 1, number of instances possible in memory ')
|
|
|
+parser.add_argument('--temperature', '-temp', type=float, default=None, help='class error based reheasal with sharpening/softening temperature value')
|
|
|
+parser.add_argument('--rehearsal_method', '-rm', type=str, default=None, help='abreviation corresponding to one of the implemented rehersal methods such as "ce, cf, rr, mir, wce, cefa, fe, feo", see line 396')
|
|
|
+parser.add_argument('--nr_of_steps_to_avg', '-nexp_avg', type=int, default=None, help='needs to be set for WeightedeMovingClassErrorAverageRehersalPlugin and defines how many past exp are used in weighted average')
|
|
|
+parser.add_argument('--sigma', '-sig', type=float, default=None, help='needs to be set for WeightedeMovingClassErrorAverageRehersalPlugin, sets the rate of decay for weights in moving weighted average. The bigger sigma, the more past error rates are weitghed')
|
|
|
+parser.add_argument('--buffer_data_ratio', '-dr', type=float, default=8, help='how many images from memory should be used in the rehersal set for each image in the experience')
|
|
|
+parser.add_argument('--shuffle_stream', '-ss', action='store_true', default=False, help='if train_stream should be shuffled on an image level to obtain mode iid distribution in stream for comparison')
|
|
|
+parser.add_argument('--num_classes', '-nc', type=int, default=16, help='number of classes in dataset')
|
|
|
+parser.add_argument('--non_animal_cls', '-nac', default=[13,14,15], nargs='+', type=int, help='classes, that do not contain species, which is relevant for "binary" accuracies' )
|
|
|
+parser.add_argument('--images_root', '-ir', type=str, default='/home/AMMOD_data/camera_traps/BayerWald/G-Fallen/original/', help='directory where image files are stored')
|
|
|
+parser.add_argument('--label_dict', '-ld', type=str, default="/home/boehlke/AMMOD/cam_trap_classification/data/csv_files/BIRDS.pkl" , help='used to create confusion matrix at c-n intervals')
|
|
|
+args = parser.parse_args()
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+print("PyTorch Version: ",torch.__version__)
|
|
|
+use_cuda = torch.cuda.is_available()
|
|
|
+global detected_device
|
|
|
+args.detected_device = torch.device("cuda:0" if use_cuda else "cpu")
|
|
|
+print(args.detected_device)
|
|
|
+number_workers = 8
|
|
|
+
|
|
|
+# Train and test data
|
|
|
+with open(args.data_root+args.file.replace('384','128')+'_train_stream.pkl', 'rb') as handle: # even for training with stream of exp size384 128 needs to be loaded first to match the winter stream files
|
|
|
+ args.train_stream = pkl.load(handle)
|
|
|
+
|
|
|
+with open(args.data_root+args.file.replace('384','128')+'_test_stream.pkl', 'rb') as handle:
|
|
|
+ args.test_stream = pkl.load(handle)#[:5000]
|
|
|
+
|
|
|
+with open(args.data_root+args.file.replace('384','128')+'_val_stream.pkl', 'rb') as handle:
|
|
|
+ args.val_stream = pkl.load(handle)#[:5000]
|
|
|
+
|
|
|
+with open(args.data_root+args.file.replace('384','128')+'_winter_val_stream.pkl', 'rb') as handle:
|
|
|
+ args.validation_stream_winter = pkl.load(handle)
|
|
|
+
|
|
|
+with open(args.data_root+args.file.replace('384','128')+'_winter_test_stream.pkl', 'rb') as handle:
|
|
|
+ args.test_stream_winter = pkl.load(handle)
|
|
|
+
|
|
|
+with open(args.data_root+args.file.replace('_crop','').replace('384','128')+'_exp_season_split_dict.pkl', 'rb') as handle:
|
|
|
+ args.exp_season_split_dict = pkl.load(handle)
|
|
|
+
|
|
|
+
|
|
|
+# Backbone model settings
|
|
|
+model_depth=18
|
|
|
+args.model = pytorchcv_wrapper.resnet('imagenet', depth=model_depth, pretrained=True)
|
|
|
+args.model.num_classes = args.num_classes
|
|
|
+
|
|
|
+num_input_ftrs = args.model.output.in_features
|
|
|
+args.model.output = nn.Linear(num_input_ftrs , args.model.num_classes)
|
|
|
+args.validation_stream_winter_files = np.array(args.validation_stream_winter)[:,0]
|
|
|
+
|
|
|
+args.validation_stream_winter_files = [args.images_root+s[0:] for s in args.validation_stream_winter_files ]
|
|
|
+args.test_stream_winter_files = np.array(args.test_stream_winter)[:,0]
|
|
|
+args.test_stream_winter_files = [args.images_root+s[0:] for s in args.test_stream_winter_files ]
|
|
|
+
|
|
|
+def flattened(list_w_sublists):
|
|
|
+ flattened = []
|
|
|
+ for item in list_w_sublists:
|
|
|
+ if isinstance(item, list):
|
|
|
+ for val in item:
|
|
|
+ flattened.append(val)
|
|
|
+ else:
|
|
|
+ flattened.append(item)
|
|
|
+ return flattened
|
|
|
+
|
|
|
+# creating the winter list from the loaded data
|
|
|
+winter_exp = args.exp_season_split_dict['winter']
|
|
|
+summer_exp = args.exp_season_split_dict['summer']
|
|
|
+winter_train_exp = [args.train_stream[i] for i in winter_exp]
|
|
|
+all_winter_train_data = [i for sublist in winter_train_exp for i in sublist]
|
|
|
+all_winter_train_files = np.array(all_winter_train_data)[:,0]
|
|
|
+all_winter_train_files = [args.images_root+s[0:] for s in all_winter_train_files ]
|
|
|
+all_test_files = np.array(args.test_stream)[:,0]
|
|
|
+all_test_files = [args.images_root+s[0:] for s in all_test_files ]
|
|
|
+all_val_files = np.array(args.validation_stream_winter)[:,0]
|
|
|
+all_val_files = [args.images_root+s[0:] for s in all_val_files ]
|
|
|
+args.all_winter_files = all_winter_train_files +args.test_stream_winter_files+args.validation_stream_winter_files
|
|
|
+
|
|
|
+
|
|
|
+# Optimizer Settings
|
|
|
+eps=1e-4
|
|
|
+weight_decay=75e-3
|
|
|
+args.optimizer = optim.AdamW(args.model.parameters(), eps=eps, weight_decay=weight_decay)
|
|
|
+args.criterion = LabelSmoothingCrossEntropy()
|
|
|
+
|
|
|
+with open(args.data_root+args.file+'_train_stream.pkl', 'rb') as handle:
|
|
|
+ args.train_stream = pkl.load(handle)
|
|
|
+
|
|
|
+if args.shuffle_stream:
|
|
|
+ train_exp_list_flatten = flattened(args.train_stream)
|
|
|
+ random.shuffle(train_exp_list_flatten)
|
|
|
+ shuffled_train_stream = []
|
|
|
+ exp_size = len(args.train_stream[0])
|
|
|
+ for i in range(len(args.train_stream)):
|
|
|
+ shuffled_train_stream.append(train_exp_list_flatten[i*exp_size:(i+1)*exp_size])
|
|
|
+ args.train_stream = shuffled_train_stream
|
|
|
+strategy = args.strategy
|
|
|
+
|
|
|
+
|
|
|
+# using parameters to assemble a name for the logging floder
|
|
|
+data_ratio =''
|
|
|
+rehersal=''
|
|
|
+if args.rehearsal_method is not None:
|
|
|
+ strategy = 'rehersal'
|
|
|
+ data_ratio = '_dataratio'+str(args.buffer_data_ratio)
|
|
|
+ rehersal = '_'+args.rehearsal_method
|
|
|
+ if ( args.rehearsal_method == 'ce' or args.rehearsal_method=='cf' ) and args.temperature is not None:
|
|
|
+ rehersal = rehersal+'_temp'+str(args.temperature)
|
|
|
+ if args.rehearsal_method == 'wce':
|
|
|
+ if args.nr_of_steps_to_avg is None or args.sigma is None:
|
|
|
+ print('ATTENTION: weighted moving average variables not set')
|
|
|
+ print(args.nr_of_steps_to_avg)
|
|
|
+ print(args.sigma)
|
|
|
+ else:
|
|
|
+ rehersal = rehersal+'_navg'+str(args.nr_of_steps_to_avg)+'_sigma'+str(args.sigma)
|
|
|
+ if args.memory_filling !='inf':
|
|
|
+ total_train_data = len(args.train_stream[0])*len(args.train_stream)
|
|
|
+ print(args.memory_size)
|
|
|
+ print(total_train_data)
|
|
|
+ if args.memory_size <1:
|
|
|
+ memory_size = int(args.memory_size*total_train_data)
|
|
|
+ print('r', memory_size)
|
|
|
+ else:
|
|
|
+ memory_size = int(args.memory_size)
|
|
|
+ print('i', memory_size)
|
|
|
+ strategy = 'rehersal_' +args.memory_filling+'_memsize_'+str(args.memory_size)
|
|
|
+ print('total memory size ', args.memory_size)
|
|
|
+ else:
|
|
|
+ print('inf_mem')
|
|
|
+ print(args.memory_filling)
|
|
|
+if args.shuffle_stream:
|
|
|
+ strategy = strategy+'_shuffled'
|
|
|
+# Logging rpaht
|
|
|
+log_root = Path('/home/boehlke/AMMOD/continual_learning/results/')
|
|
|
+if args.log_root is not None:
|
|
|
+ log_root = Path(args.log_root)
|
|
|
+model_settings = 'resnet'+str(model_depth)+'_smoothloss_adamW_eps'+str(eps)+'wd'+str(weight_decay)+'_bs'+str(args.batch_size)
|
|
|
+strategy_settings = strategy+rehersal+data_ratio+'_ep'+str(args.epochs)+'_ne'+str(args.eval_after_n_exp)
|
|
|
+args.log_dir_name = log_root / model_settings/ args.file / strategy_settings
|
|
|
+print(args.log_dir_name)
|
|
|
+Path.mkdir(args.log_dir_name, parents=True, exist_ok=True)
|
|
|
+
|
|
|
+
|
|
|
+cls_in_teststream =np.unique(np.array(args.test_stream)[:,1])
|
|
|
+if cls_in_teststream.shape[0] != args.num_classes:
|
|
|
+ print('ATTENTION: args.num_classes doesnt mathc number of classes in test_stream')
|
|
|
+ print(args.num_classes)
|
|
|
+
|
|
|
+ #exit(0)
|
|
|
+
|
|
|
+cls_in_valstream =np.unique(np.array(args.val_stream)[:,1])
|
|
|
+if cls_in_valstream.shape[0] != args.num_classes:
|
|
|
+ print('ATTENTION: args.num_classes doesnt mathc number of classes in val_stream')
|
|
|
+ print(args.num_classes)
|
|
|
+ #exit(0)
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+# Special case for
|
|
|
+with open(args.label_dict, 'rb') as p:
|
|
|
+ label_dict = pkl.load(p)
|
|
|
+print(label_dict.items())
|
|
|
+
|
|
|
+if not 0 in label_dict.keys(): # when the class names arte the keys an the integer values are the values
|
|
|
+ print('label_dict wrong way around')
|
|
|
+ flip_lable_dict = {}
|
|
|
+ for key,value in label_dict.items():
|
|
|
+ flip_lable_dict[value] = key
|
|
|
+
|
|
|
+ label_dict = flip_lable_dict
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+# augmentations for train and val/test data
|
|
|
+resize = int(224/0.875)
|
|
|
+args.data_transforms = {
|
|
|
+ 'train': transforms.Compose([
|
|
|
+ transforms.Resize((resize,resize)),
|
|
|
+ transforms.RandomCrop(size=(224,224)),
|
|
|
+ transforms.RandomHorizontalFlip(),
|
|
|
+ transforms.ToTensor(),
|
|
|
+ #transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1),
|
|
|
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
|
|
+ ]),
|
|
|
+ 'val': transforms.Compose([
|
|
|
+ transforms.Resize((224,224)),
|
|
|
+ transforms.ToTensor(),
|
|
|
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
|
|
+ ]),
|
|
|
+ }
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+def get_scenario(args):
|
|
|
+
|
|
|
+ '''Returns a benchmark instance given a sequence of lists of files. A separate
|
|
|
+ dataset will be created for each list. Each of those datasets
|
|
|
+ will be considered a separate experience.
|
|
|
+
|
|
|
+ function paths_benchmark() is defined in /avalanche/scenarios/generic_benchmark_creation.py
|
|
|
+ under "create_generic_benchmark_from_paths()" which has been augmented from
|
|
|
+ the original avalanche version with a path_dataset_class variable which allows for handling of data
|
|
|
+ with more than path and label infomation. In this case, the sequence code is passed as well to allow
|
|
|
+ for sequencewise evaluation.
|
|
|
+ '''
|
|
|
+
|
|
|
+ if args.strategy == 'joint':
|
|
|
+ joint_train_stream = [item for sublist in args.train_stream for item in sublist ]
|
|
|
+ args.train_stream = [joint_train_stream]
|
|
|
+
|
|
|
+ task_labels = np.zeros((1,len(args.train_stream))).astype(int).tolist()[0]
|
|
|
+
|
|
|
+ nac =torch.tensor(args.non_animal_cls, device=args.detected_device)
|
|
|
+ nc = args.num_classes
|
|
|
+
|
|
|
+ metrics = [
|
|
|
+ #TrainingTime(reset_at='epoch', emit_at='epoch', mode='train'),
|
|
|
+ #accuracy_metrics(minibatch=False, epoch=True, experience=False, stream=True),
|
|
|
+ Top1AccTransfer(nr_train_exp=len(args.train_stream), reset_at='stream', emit_at='stream', mode='eval'),
|
|
|
+ ClassTop_nAvgAcc(nr_classes=nc, n=1, reset_at='stream', emit_at='stream', mode='eval'),
|
|
|
+ #ClassTop_nAvgAcc(nr_classes=nc, n=1,reset_at='epoch', emit_at='epoch', mode='train'),
|
|
|
+ ClassTop_nAvgAcc(nr_classes=nc, n=2, reset_at='stream', emit_at='stream', mode='eval'),
|
|
|
+ SeasonClassTop_nAcc(nr_classes=nc, n=1, winter_files_list=args.all_winter_files, season='summer', reset_at='stream', emit_at='stream', mode='eval'),
|
|
|
+ SeasonClassTop_nAcc(nr_classes=nc, n=1, winter_files_list=args.all_winter_files, season='winter', reset_at='stream', emit_at='stream', mode='eval'),
|
|
|
+ #ClassTop_nAvgAcc(nr_classes=nc, n=2,reset_at='epoch', emit_at='epoch', mode='train'),
|
|
|
+ ClasswiseTop_nAcc(nr_classes=nc, n=1, reset_at='stream', emit_at='stream', mode='eval'),
|
|
|
+ Top_nAcc(n=1, reset_at='stream', emit_at='stream', mode='eval'),
|
|
|
+ SeasonTop_nAcc(n=1, winter_files_list=args.all_winter_files, season='summer', reset_at='stream', emit_at='stream', mode='eval'),
|
|
|
+ SeasonTop_nAcc(n=1, winter_files_list=args.all_winter_files, season='winter', reset_at='stream', emit_at='stream', mode='eval'),
|
|
|
+ #Top_nAcc(n=2, reset_at='stream', emit_at='stream', mode='eval'),
|
|
|
+ #Top_nAcc(n=1, reset_at='epoch', emit_at='epoch', mode='train'),
|
|
|
+ #Top_nAcc(n=2, reset_at='epoch', emit_at='epoch', mode='train'),
|
|
|
+ SeqTop_nAcc(n=1, reset_at='stream', emit_at='stream', mode='eval'),
|
|
|
+ SeqMaxAcc( reset_at='stream', emit_at='stream', mode='eval'),
|
|
|
+ SeasonSeqTop_nAcc(num=1, winter_files_list=args.all_winter_files, season='summer', reset_at='stream', emit_at='stream', mode='eval'),
|
|
|
+ SeasonSeqTop_nAcc(num=1, winter_files_list=args.all_winter_files, season='winter', reset_at='stream', emit_at='stream', mode='eval'),
|
|
|
+ #SeqTop_nAcc(n=2, reset_at='stream', emit_at='stream', moall_winter_filesdall_winter_filese='eval'),
|
|
|
+ #SeqTop_nAcc(n=3, reset_at='stream', emit_at='stream', mode='eval'),
|
|
|
+ #SeqTop_nAcc(n=1, reset_at='epoch', emit_at='epoch', mode='train'),
|
|
|
+ #SeqTop_nAcc(n=2, reset_at='epoch', emit_at='epoch', mode='train'),
|
|
|
+ #SeqTop_nAcc(n=3, reset_at='epoch', emit_at='epoch', mode='train'),
|
|
|
+ SeqClasswiseTop_nAcc(nr_classes=nc, n=1, reset_at='stream', emit_at='stream', mode='eval'),
|
|
|
+ SeqClassTop_nAvgAcc(nr_classes=nc, n=1, reset_at='stream', emit_at='stream', mode='eval'),
|
|
|
+ SeqClassAvgMaxAcc(nr_classes=nc, reset_at='stream', emit_at='stream', mode='eval'),
|
|
|
+ SeasonSeqClassTop_nAcc(nr_classes=nc, n=1, winter_files_list=args.all_winter_files, season='summer', reset_at='stream', emit_at='stream', mode='eval'),
|
|
|
+ SeasonSeqClassTop_nAcc(nr_classes=nc, n=1, winter_files_list=args.all_winter_files, season='winter', reset_at='stream', emit_at='stream', mode='eval'),
|
|
|
+ #SeqClassTop_nAvgAcc(nr_classes=nc, n=1,reset_at='epoch', emit_at='epoch', mode='train'),
|
|
|
+ #SeqClassTop_nAvgAcc(nr_classes=nc, n=2, reset_at='stream', emit_at='stream', mode='eval'),
|
|
|
+ #SeqClassTop_nAvgAcc(nr_classes=nc, n=2,reset_at='epoch', emit_at='epoch', mode='train'),
|
|
|
+ SeqAnyAcc(reset_at='stream', emit_at='stream', mode='eval'),
|
|
|
+ #SeqAnyAcc(reset_at='epoch', emit_at='epoch', mode='train'),
|
|
|
+ #BinaryAnimalAcc(non_animal_class =nac, reset_at='epoch', emit_at='epoch', mode='train'),
|
|
|
+ #BinaryAnimalAcc(non_animal_class =nac,reset_at='stream', emit_at='stream', mode='eval'),
|
|
|
+ #QuatrupleAnimalAcc(non_animal_class=nac, reset_at='epoch', emit_at='epoch', mode='train'),
|
|
|
+ #QuatrupleAnimalAcc(non_animal_class=nac,reset_at='stream', emit_at='stream', mode='eval'),
|
|
|
+ #SeqBinaryAnimalAcc(non_animal_class=nac,reset_at='epoch', emit_at='epoch', mode='train'),
|
|
|
+ #SeqBinaryAnimalAcc(non_animal_class=nac,reset_at='stream', emit_at='stream', mode='eval'),
|
|
|
+ #SeqIgnoreNonAnimalOutputsAcc(non_animal_class=nac,nr_classes=nc, reset_at='epoch', emit_at='epoch', mode='train'),
|
|
|
+ #SeqIgnoreNonAnimalOutputsAcc(non_animal_class=nac,nr_classes=nc,reset_at='stream', emit_at='stream', mode='eval'),
|
|
|
+ #ClassIgnoreNonAnimalOutputsTop_nAvgAcc(nr_classes=nc,n=1,non_animal_class=nac,reset_at='epoch', emit_at='epoch', mode='train'),
|
|
|
+ #ClassIgnoreNonAnimalOutputsTop_nAvgAcc(nr_classes=nc,n=1, non_animal_class=nac,reset_at='stream', emit_at='stream', mode='eval'),
|
|
|
+ #IgnoreNonAnimalOutputsTop_nAcc(non_animal_class=nac, nr_classes=nc, n=1, reset_at='epoch', emit_at='epoch', mode='train'),
|
|
|
+ #IgnoreNonAnimalOutputsTop_nAcc(non_animal_class=nac,nr_classes=nc, n=1, reset_at='stream', emit_at='stream', mode='eval'),
|
|
|
+ #NrStepsBatchwise(reset_at='epoch', emit_at='epoch', mode='train'),
|
|
|
+ #NrStepsImagewise(reset_at='epoch', emit_at='epoch', mode='train')
|
|
|
+ ]
|
|
|
+
|
|
|
+
|
|
|
+ metrics = flattened(metrics)
|
|
|
+
|
|
|
+ datetime_stamp = str(datetime.datetime.now()).replace(' ','_').replace(':','-').replace('.','-')[:-4]
|
|
|
+ tb_logger = TensorboardLogger(tb_log_dir=str(args.log_dir_name)+"/tb_data_"+datetime_stamp, filename_suffix='test_run')
|
|
|
+ interactive_logger = InteractiveLogger()
|
|
|
+ gen_csv_logger = GenericCSVLogger(log_folder=str(args.log_dir_name)+'/gen_csvlogs_'+datetime_stamp)
|
|
|
+
|
|
|
+ args.combined_logger = EvaluationPlugin(
|
|
|
+ metrics,
|
|
|
+ loggers=[tb_logger, gen_csv_logger, interactive_logger],
|
|
|
+ suppress_warnings=False)
|
|
|
+
|
|
|
+
|
|
|
+ scenario = paths_benchmark(
|
|
|
+ args.train_stream,
|
|
|
+ [args.test_stream],
|
|
|
+ other_streams_lists_of_files={'validation': [args.val_stream]},
|
|
|
+ task_labels=task_labels,
|
|
|
+ complete_test_set_only=True,
|
|
|
+ train_transform=args.data_transforms['train'],
|
|
|
+ eval_transform=args.data_transforms['val'],
|
|
|
+ other_streams_transforms={'validation_stream': args.data_transforms['val']},
|
|
|
+ path_dataset_class=SeqPathsDataset,
|
|
|
+ common_root = args.images_root
|
|
|
+ )
|
|
|
+
|
|
|
+ return scenario
|
|
|
+
|
|
|
+def get_strategy(args):
|
|
|
+
|
|
|
+ ''' Returns the avalanche training strategy with the necessary plugins '''
|
|
|
+ plugin = SeqDataPlugin()
|
|
|
+ plugins = [plugin]
|
|
|
+
|
|
|
+ if args.memory_filling=='inf':
|
|
|
+ if args.rehearsal_method == 'ce':
|
|
|
+ if args.temperature is None:
|
|
|
+ plugin_rehearsal = ClassErrorRehersalPlugin(buffer_data_ratio= args.buffer_data_ratio)
|
|
|
+ plugins.append(plugin_rehearsal)
|
|
|
+ else:
|
|
|
+ plugin_rehearsal = ClassErrorRehersalTemperaturePlugin(buffer_data_ratio= args.buffer_data_ratio, temperature=args.temperature)
|
|
|
+ plugins.append(plugin_rehearsal)
|
|
|
+ elif args.rehearsal_method == 'cf':
|
|
|
+ plugin_rehearsal = ClassFrequencyRehearsalPlugin(buffer_data_ratio = args.buffer_data_ratio)
|
|
|
+ plugins.append(plugin_rehearsal)
|
|
|
+ elif args.rehearsal_method == 'mir':
|
|
|
+ plugin_rehearsal = MaximallyInterferedRetrievalRehersalPlugin(buffer_data_ratio=args.buffer_data_ratio)
|
|
|
+ plugins.append(plugin_rehearsal)
|
|
|
+ elif args.rehearsal_method == 'rr':
|
|
|
+ plugin_rehearsal = RandomRehersal(buffer_data_ratio=args.buffer_data_ratio)
|
|
|
+ plugins.append(plugin_rehearsal)
|
|
|
+ elif args.rehearsal_method == 'wce':
|
|
|
+ plugin_rehearsal = WeightedeMovingClassErrorAverageRehersalPlugin( buffer_data_ratio=args.buffer_data_ratio, nr_of_steps_to_avg=args.nr_of_steps_to_avg, nr_classes=args.num_classes, sigma=args.sigma)
|
|
|
+ plugins.append(plugin_rehearsal)
|
|
|
+ elif args.rehearsal_method == 'cefa':
|
|
|
+ plugin_rehearsal = ClassErrorFrequencyAvgRehearsalPlugin(buffer_data_ratio=args.buffer_data_ratio, nr_classes=args.num_classes )
|
|
|
+ plugins.append(plugin_rehearsal)
|
|
|
+ elif args.rehearsal_method == 'fe':
|
|
|
+ plugin_rehearsal = FillExpBasedRehearsalPlugin(buffer_data_ratio=args.buffer_data_ratio, nr_classes=args.num_classes )
|
|
|
+ plugins.append(plugin_rehearsal)
|
|
|
+ elif args.rehearsal_method == 'feo':
|
|
|
+ plugin_rehearsal = FillExpOversampleBasedRehearsalPlugin(buffer_data_ratio=args.buffer_data_ratio, nr_classes=args.num_classes )
|
|
|
+ plugins.append(plugin_rehearsal)
|
|
|
+ elif args.memory_filling == 'cbrs':
|
|
|
+ plugin_rehearsal = ClassBalancingReservoirMemoryRehersalPlugin(buffer_data_ratio=args.buffer_data_ratio, memory_size=memory_size, nr_classes=args.num_classes, rehearsal_selection_strategy=args.rehearsal_method, temperature=args.temperature)
|
|
|
+ plugins.append(plugin_rehearsal)
|
|
|
+ elif args.memory_filling == 'stdrs':
|
|
|
+ plugin_rehearsal = ReservoirMemoryRehearsalPlugin(buffer_data_ratio=args.buffer_data_ratio, memory_size=memory_size, nr_classes=args.num_classes, rehearsal_selection_strategy=args.rehearsal_method, temperature=args.temperature)
|
|
|
+ plugins.append(plugin_rehearsal)
|
|
|
+ print('no rehearsal method selected')
|
|
|
+ if args.strategy == 'joint':
|
|
|
+ ''' In the joint baseline scenario the naive base strategy defined in /avalanche/training/strategies/base_strategies.py
|
|
|
+ is used
|
|
|
+ '''
|
|
|
+ print('getting joint strategy')
|
|
|
+ strategy = Naive(
|
|
|
+ args.model,
|
|
|
+ args.optimizer,
|
|
|
+ args.criterion,
|
|
|
+ train_mb_size=args.batch_size,
|
|
|
+ train_epochs=1,
|
|
|
+ eval_mb_size=args.exp_size,
|
|
|
+ device=args.detected_device,
|
|
|
+ plugins=[plugin],
|
|
|
+ evaluator=args.combined_logger,
|
|
|
+ label_dict=label_dict
|
|
|
+ )
|
|
|
+
|
|
|
+ elif args.strategy == 'cumulative':
|
|
|
+ print('getting cumulative strategy')
|
|
|
+ strategy = Cumulative(
|
|
|
+ args.model,
|
|
|
+ args.optimizer,
|
|
|
+ args.criterion,
|
|
|
+ train_mb_size=args.batch_size,
|
|
|
+ train_epochs=args.epochs,
|
|
|
+ eval_mb_size=args.exp_size,
|
|
|
+ device=args.detected_device,
|
|
|
+ plugins=[plugin],
|
|
|
+ evaluator=args.combined_logger,
|
|
|
+ label_dict=label_dict
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ print('getting naive strategy')
|
|
|
+ strategy = Naive(
|
|
|
+ args.model,
|
|
|
+ args.optimizer,
|
|
|
+ args.criterion,
|
|
|
+ train_mb_size=args.batch_size,
|
|
|
+ train_epochs=args.epochs,
|
|
|
+ eval_mb_size=args.exp_size,
|
|
|
+ device=args.detected_device,
|
|
|
+ plugins=plugins,
|
|
|
+ evaluator=args.combined_logger,
|
|
|
+ label_dict=label_dict
|
|
|
+ )
|
|
|
+ return strategy
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+scenario = get_scenario(args)
|
|
|
+cl_strategy = get_strategy(args)
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+# different loops are necessary for different strategies.
|
|
|
+# when using class error based method, the evaluation on the validation data should occure more frequently, meaning the variable eval_after_n_exp (-ne) should be less than 10
|
|
|
+if args.strategy == 'joint':
|
|
|
+
|
|
|
+ for i in range(args.epochs):
|
|
|
+ #print('joint ', dir(cl_strategy) )
|
|
|
+ cl_strategy.train(scenario.train_stream[0], num_workers=number_workers)
|
|
|
+ print('Training completed')
|
|
|
+ last_eval = i ==args.epochs-1
|
|
|
+
|
|
|
+ if i%args.test_eval_after_n_exp==0 or last_eval:
|
|
|
+ print('Computing accuracy on the whole test set')
|
|
|
+ cl_strategy.eval(scenario.test_stream, num_workers=number_workers, last_eval=True)
|
|
|
+ if i%args.eval_after_n_exp==0 or last_eval:
|
|
|
+ print('Computing accuracy on the whole validation set')
|
|
|
+ cl_strategy.eval(scenario.validation_stream, num_workers=number_workers, last_eval=True)
|
|
|
+ if last_eval:
|
|
|
+ cl_strategy.eval(scenario.train_stream, num_workers=number_workers, last_eval=False)
|
|
|
+
|
|
|
+
|
|
|
+else:
|
|
|
+ train_stream_len = len(scenario.train_stream)
|
|
|
+ for i, experience in enumerate(scenario.train_stream):
|
|
|
+ cl_strategy.train(experience, num_workers=number_workers)
|
|
|
+ print('Training completed')
|
|
|
+ last_eval = i ==train_stream_len-1
|
|
|
+ if (i%args.eval_after_n_exp==0 or last_eval):
|
|
|
+ print('Evaluaton after experience: ', i )
|
|
|
+ #cl_strategy.eval(scenario.train_stream, num_workers=number_workers, last_eval=True)
|
|
|
+ print('Computing accuracy on the whole validation set')
|
|
|
+ cl_strategy.eval(scenario.validation_stream, num_workers=number_workers, last_eval=True)
|
|
|
+
|
|
|
+
|
|
|
+ if (i%args.test_eval_after_n_exp==0 or last_eval):
|
|
|
+ print('Evaluaton after experience: ', i )
|
|
|
+ print('Computing accuracy on the whole test set')
|
|
|
+ cl_strategy.eval(scenario.test_stream, num_workers=number_workers, last_eval=True)
|
|
|
+
|
|
|
+ if last_eval:
|
|
|
+
|
|
|
+ cl_strategy.eval(scenario.train_stream, num_workers=number_workers, last_eval=True)
|