Просмотр исходного кода

bringing it all together in a unified script

Julia Boehlke 2 лет назад
Родитель
Сommit
01e36b8cf6

+ 2 - 1
README.md

@@ -2,7 +2,8 @@ This Repository contains a folder called avalanche which is basically a fork fro
 
 https://github.com/ContinualAI/avalanche
 
-Further the camera trap data and the way it was handled is described. The corresponding jupyter notbook is found in scripts/jupyter-notebooks
+The camera trap data and the way it was handled are described. The corresponding Jupyter notebook is found in scripts/jupyter-notebooks/data_stream_creation.ipynb
+The scripts/jupyter-notebooks/introduction_to_avalanche.ipynb was written to explain the avalanche framework and the extensions I made. 
 
 After cloning the package, create a new environment with Python 3.8:
 

+ 2 - 2
avalanche/avalanche/training/plugins/class_balancing_memory.py

@@ -213,15 +213,14 @@ class ReservoirMemoryRehearsalPlugin(ClassImbalanceMemoryRehersalPlugin):
 		
 		self.nr_classes = nr_classes
 		self.total_stream_length = 0
-		
 
 
 	def after_training_exp(self, strategy: 'BaseStrategy', **kwargs):
-
 		if self.memory_dataset is None:
 			self.memory_dataset = strategy.experience.dataset
 			self.total_stream_length = len(strategy.experience.dataset)
 
+
 		elif len(self.memory_dataset)+len(strategy.experience.dataset) <= self.memory_size:
 			combined_paths = self.memory_dataset.paths+strategy.experience.dataset.paths
 
@@ -230,6 +229,7 @@ class ReservoirMemoryRehearsalPlugin(ClassImbalanceMemoryRehersalPlugin):
 			self.total_stream_length += len(strategy.experience.dataset)
 
 		else:
+
 			if len(self.memory_dataset)==self.memory_size:
 				experience_data = strategy.experience.dataset
 

+ 1 - 2
avalanche/avalanche/training/strategies/base_strategy.py

@@ -214,8 +214,7 @@ class BaseStrategy:
             self.total_test_cls_dict[i] = 0
             self.total_validation_cls_dict[i] = 0
 
-        # variables to track class wise data occurances
-
+        # variables to track class wise data occurances 
 
         self.label_dict = label_dict 
         #""" Dictionary with int-labels as keys and class names as values """

+ 520 - 0
scripts/cont_ava.py

@@ -0,0 +1,520 @@
+import pandas as pd
+from skimage import io
+import numpy as np
+from PIL import Image
+import torch
+import torch.nn as nn  # All neural network modules, nn.Linear, nn.Conv2d, BatchNorm, Loss functions
+import torch.optim as optim  # For all Optimization algorithms, SGD, Adam, etc.
+import torchvision.transforms as transforms  # Transformations we can perform on our dataset
+from torchvision.io import read_image
+import os
+from torch.utils.data import (
+    Dataset,
+    DataLoader,
+)  
+from typing import Optional, Sequence, Union, List
+from avalanche.benchmarks.generators import nc_benchmark, ni_benchmark
+from avalanche.models import SimpleCNN 
+from avalanche.models import pytorchcv_wrapper
+from avalanche.benchmarks.utils.datasets_from_filelists import SeqPathsDataset, PathsDataset
+from avalanche.training.strategies import Naive, Cumulative, JointTraining
+#os.environ['CUDA_LAUNCH_BLOCKING'] = "0"
+from avalanche.benchmarks.generators import paths_benchmark
+#from avalanche.training.plugins.sequence_data import _get_random_indicies, _get_cls_acc_based_indicies
+from avalanche.training.plugins import (SeqDataPlugin, 
+                                        ReplayPlugin, 
+                                        EvaluationPlugin, 
+                                        MaximallyInterferedRetrievalRehersalPlugin, 
+                                        ClassErrorRehersalPlugin,
+                                        ClassErrorRehersalTemperaturePlugin,
+                                        ClassFrequencyRehearsalPlugin,
+                                        RandomRehersal,
+                                        ClassBalancingReservoirMemoryRehersalPlugin,
+                                        ReservoirMemoryRehearsalPlugin,
+                                        WeightedeMovingClassErrorAverageRehersalPlugin,
+                                        ClassErrorFrequencyAvgRehearsalPlugin,
+                                        FillExpBasedRehearsalPlugin,
+                                        FillExpOversampleBasedRehearsalPlugin)
+from avalanche.evaluation.metrics import (ClassTop_nAvgAcc,  
+                                          SeasonClassTop_nAcc,
+                                          Top1AccTransfer,
+                                          ClasswiseTop_nAcc, 
+                                          Top_nAcc,
+                                          SeasonTop_nAcc,
+                                          SeqMaxAcc,
+                                          SeqClassTop_nAvgAcc,
+                                          SeqClassAvgMaxAcc,
+                                          SeasonSeqClassTop_nAcc,
+                                          SeqClasswiseTop_nAcc,
+                                          SeqAnyAcc,
+                                          SeasonSeqTop_nAcc,
+                                          SeqTop_nAcc,
+                                          BinaryAnimalAcc,
+                                          QuatrupleAnimalAcc,
+                                          SeqBinaryAnimalAcc,
+                                          IgnoreNonAnimalOutputsTop_nAcc,
+                                          SeqIgnoreNonAnimalOutputsAcc,
+                                          ClassIgnoreNonAnimalOutputsTop_nAvgAcc
+                                         )
+from avalanche.logging import StrategyLogger, InteractiveLogger, TextLogger, TensorboardLogger, CSVLogger, GenericCSVLogger
+from avalanche.evaluation.metrics import accuracy_metrics, loss_metrics
+import matplotlib.pyplot as plt
+import pickle as pkl
+import time
+import argparse
+import threading
+from avalanche.models.utils import LabelSmoothingCrossEntropy
+from pathlib import Path
+import datetime
+import sys  
+import random
+
+# Command-line interface of the script. Every option ends up as an attribute of `args`,
+# which is also used further down as a container for run-wide objects (streams, model, ...).
+parser = argparse.ArgumentParser()
+# --file is required: all stream file names are derived from it, so a missing value
+# would otherwise only surface later as an AttributeError on None.
+parser.add_argument('--file', '-f', type=str, required=True, help='file name where train/test/val_stream.pkl should be loaded')
+parser.add_argument('--data_root', '-roots', type=str, default='/home/boehlke/AMMOD/continual_learning/data/boehlke_ssd_Species/', help='directory containing the <file>_train/test/val_stream.pkl files')
+parser.add_argument('--log_root', '-log', type=str, default='/home/boehlke/AMMOD/continual_learning/results_ever_eval_test/', help='file results shall be logged')
+parser.add_argument('--strategy', '-s', type=str, default='naive', help='strategy name')
+parser.add_argument('--epochs', '-e', type=int, default=3, help='number of times the finetuning set (rehearsal+experience) are itereted over')
+parser.add_argument('--batch_size', '-bs', type=int, default=48)
+parser.add_argument('--eval_after_n_exp', '-ne', type=int, default=1, help='interval of evaluation on validation set in experiences, in case of joint training, interval in epochs')
+parser.add_argument('--test_eval_after_n_exp', '-tne', type=int, default=100, help='interval of evaluation on test set in experiences, in case of joint training, interval in epochs')
+parser.add_argument('--exp_size', '-es', type=int, default=128, help='exp size, also used as eval batch exp_size' )
+# The values listed in the help text are the ones the strategy setup actually compares against.
+parser.add_argument('--memory_filling', '-memf', type=str, default='inf', help='how the memory shall be filled, default is inf, other options(cbrs, stdrs)')
+parser.add_argument('--memory_size', '-mems', type=float, default=0.3, help='either ratio of entire train stream saved in memory, or if larger than 1, number of instances possible in memory ')
+parser.add_argument('--temperature', '-temp', type=float, default=None, help='class error based reheasal with sharpening/softening temperature value')
+parser.add_argument('--rehearsal_method', '-rm', type=str, default=None, help='abreviation corresponding to one of the implemented rehersal methods such as "ce, cf, rr, mir, wce, cefa, fe, feo", see get_strategy()')
+parser.add_argument('--nr_of_steps_to_avg', '-nexp_avg', type=int, default=None, help='needs to be set for WeightedeMovingClassErrorAverageRehersalPlugin and defines how many past exp are used in weighted average')
+parser.add_argument('--sigma', '-sig', type=float, default=None, help='needs to be set for WeightedeMovingClassErrorAverageRehersalPlugin, sets the rate of decay for weights in moving weighted average. The bigger sigma, the more past error rates are weitghed')
+parser.add_argument('--buffer_data_ratio', '-dr', type=float, default=8, help='how many images from memory should be used in the rehersal set for each image in the experience')
+parser.add_argument('--shuffle_stream', '-ss', action='store_true', default=False, help='if train_stream should be shuffled on an image level to obtain mode iid distribution in stream for comparison')
+parser.add_argument('--num_classes', '-nc', type=int, default=16, help='number of classes in dataset')
+parser.add_argument('--non_animal_cls', '-nac', default=[13,14,15], nargs='+', type=int, help='classes, that do not contain species, which is relevant for "binary" accuracies' )
+parser.add_argument('--images_root', '-ir', type=str, default='/home/AMMOD_data/camera_traps/BayerWald/G-Fallen/original/', help='directory where image files are stored')
+parser.add_argument('--label_dict', '-ld', type=str, default="/home/boehlke/AMMOD/cam_trap_classification/data/csv_files/BIRDS.pkl" , help='used to create confusion matrix at c-n intervals')
+args = parser.parse_args()
+
+
+
+# Report the PyTorch version and select the compute device (first CUDA GPU if one is available, else CPU).
+print("PyTorch Version: ",torch.__version__)
+use_cuda = torch.cuda.is_available()
+# NOTE(review): a `global` statement at module level has no effect; the device is stored on `args` below.
+global detected_device
+args.detected_device = torch.device("cuda:0" if use_cuda else "cpu")
+print(args.detected_device)
+# Passed as `num_workers` to every train()/eval() call at the bottom of the script.
+number_workers = 8
+
+# Train and test data
+# All stream files are pickles named <data_root><file>_<stream>.pkl. Any '384' in --file is replaced
+# by '128' for these loads; the training stream for the requested --file is loaded again further down.
+# NOTE(review): pickle.load can execute arbitrary code -- only point --data_root at trusted files.
+with open(args.data_root+args.file.replace('384','128')+'_train_stream.pkl', 'rb') as handle: # even when training with experience size 384, the 128 stream is loaded first to match the winter stream files
+    args.train_stream = pkl.load(handle)
+
+with open(args.data_root+args.file.replace('384','128')+'_test_stream.pkl', 'rb') as handle:
+    args.test_stream = pkl.load(handle)#[:5000]
+
+with open(args.data_root+args.file.replace('384','128')+'_val_stream.pkl', 'rb') as handle:
+    args.val_stream = pkl.load(handle)#[:5000]
+
+with open(args.data_root+args.file.replace('384','128')+'_winter_val_stream.pkl', 'rb') as handle:
+    args.validation_stream_winter = pkl.load(handle)
+
+with open(args.data_root+args.file.replace('384','128')+'_winter_test_stream.pkl', 'rb') as handle:
+    args.test_stream_winter = pkl.load(handle)
+
+# The season split dictionary is looked up without the '_crop' part of the file name.
+with open(args.data_root+args.file.replace('_crop','').replace('384','128')+'_exp_season_split_dict.pkl', 'rb') as handle:
+    args.exp_season_split_dict = pkl.load(handle)
+
+
+# Backbone model settings
+# ResNet-18 from the pytorchcv wrapper, created with pretrained=True for 'imagenet'.
+model_depth=18
+args.model = pytorchcv_wrapper.resnet('imagenet', depth=model_depth, pretrained=True)
+args.model.num_classes = args.num_classes
+
+# Replace the final layer so that the model outputs --num_classes logits.
+num_input_ftrs = args.model.output.in_features
+args.model.output = nn.Linear(num_input_ftrs , args.model.num_classes)
+# Column 0 of each winter stream is prefixed with --images_root below.
+# assumes every stream entry is a row whose first element is a relative image path -- TODO confirm
+args.validation_stream_winter_files = np.array(args.validation_stream_winter)[:,0]
+
+args.validation_stream_winter_files = [args.images_root+s[0:] for s in args.validation_stream_winter_files ]
+args.test_stream_winter_files = np.array(args.test_stream_winter)[:,0]
+args.test_stream_winter_files = [args.images_root+s[0:] for s in args.test_stream_winter_files ]
+
+def flattened(list_w_sublists):
+    ''' Returns a new list in which every element of list_w_sublists that is itself a
+    list has been replaced by its items; all other elements are kept unchanged.
+    Only one level of nesting is removed, and only for `list` instances.
+    '''
+    result = []
+    for entry in list_w_sublists:
+        result.extend(entry if isinstance(entry, list) else [entry])
+    return result
+
+# creating the winter list from the loaded data
+# The season split dictionary holds, per season, the indices of the training experiences of that season.
+winter_exp = args.exp_season_split_dict['winter']
+# NOTE(review): summer_exp is not referenced again in the visible script.
+summer_exp = args.exp_season_split_dict['summer']
+# Collect all entries of the winter training experiences and build their full image paths.
+winter_train_exp = [args.train_stream[i] for i in winter_exp]
+all_winter_train_data = [i for sublist in winter_train_exp for i in sublist]
+all_winter_train_files = np.array(all_winter_train_data)[:,0]
+all_winter_train_files = [args.images_root+s[0:] for s in all_winter_train_files ]
+# NOTE(review): all_test_files and all_val_files are not referenced again in the visible script;
+# all_val_files is built from the winter validation stream, not from args.val_stream -- confirm intent.
+all_test_files = np.array(args.test_stream)[:,0]
+all_test_files = [args.images_root+s[0:] for s in all_test_files ]
+all_val_files = np.array(args.validation_stream_winter)[:,0]
+all_val_files = [args.images_root+s[0:] for s in all_val_files ]
+# Full paths of all winter images (train + test + validation); handed to the season-specific metrics.
+args.all_winter_files = all_winter_train_files +args.test_stream_winter_files+args.validation_stream_winter_files
+
+
+# Optimizer Settings
+# eps and weight_decay are also written into the name of the logging directory further down.
+eps=1e-4
+weight_decay=75e-3
+# No learning rate is passed, so AdamW uses its library default.
+args.optimizer = optim.AdamW(args.model.parameters(), eps=eps, weight_decay=weight_decay)
+args.criterion = LabelSmoothingCrossEntropy()
+
+# Load the training stream again, this time for the unmodified --file (no '384' -> '128' replacement).
+with open(args.data_root+args.file+'_train_stream.pkl', 'rb') as handle: 
+    args.train_stream = pkl.load(handle)
+
+# Optional image-level shuffling: flatten all experiences, shuffle, and cut the result back
+# into the same number of experiences, each with the size of the first original experience.
+# NOTE(review): if the original experiences differ in size, this re-chunking drops items or
+# produces short/empty experiences. random.shuffle is not seeded in the visible script.
+if args.shuffle_stream: 
+    train_exp_list_flatten = flattened(args.train_stream)
+    random.shuffle(train_exp_list_flatten)
+    shuffled_train_stream = []
+    exp_size = len(args.train_stream[0])
+    for i in range(len(args.train_stream)):
+        shuffled_train_stream.append(train_exp_list_flatten[i*exp_size:(i+1)*exp_size])
+    args.train_stream = shuffled_train_stream
+# `strategy` is only the name used for the logging directory; it is rewritten below.
+strategy = args.strategy
+
+
+# using parameters to assemble a name for the logging folder
+data_ratio =''
+rehersal=''
+if args.rehearsal_method is not None:
+    # Any rehearsal run is logged under 'rehersal...' regardless of the --strategy value.
+    strategy = 'rehersal'
+    data_ratio = '_dataratio'+str(args.buffer_data_ratio)
+    rehersal = '_'+args.rehearsal_method
+    if ( args.rehearsal_method == 'ce' or args.rehearsal_method=='cf' )  and args.temperature is not None:
+        rehersal = rehersal+'_temp'+str(args.temperature)
+    if args.rehearsal_method == 'wce':
+        # Only warns; the run continues even when the moving-average parameters are missing.
+        if args.nr_of_steps_to_avg is None or args.sigma is None:
+            print('ATTENTION: weighted moving average variables not set') 
+            print(args.nr_of_steps_to_avg)
+            print(args.sigma)
+        else: 
+            rehersal = rehersal+'_navg'+str(args.nr_of_steps_to_avg)+'_sigma'+str(args.sigma)
+    if args.memory_filling !='inf':
+        # Bounded memory: --memory_size below 1 is a fraction of the training stream (its length
+        # is estimated as first experience size * number of experiences), otherwise an instance count.
+        # NOTE(review): `memory_size` is only assigned inside this branch; it does not exist when
+        # --rehearsal_method is unset or --memory_filling is 'inf'.
+        total_train_data = len(args.train_stream[0])*len(args.train_stream)
+        print(args.memory_size)
+        print(total_train_data)
+        if args.memory_size <1:
+            memory_size = int(args.memory_size*total_train_data)
+            print('r', memory_size)
+        else: 
+            memory_size = int(args.memory_size)
+            print('i', memory_size)
+        # The directory name and the message below use the raw --memory_size value, not the computed one.
+        strategy = 'rehersal_' +args.memory_filling+'_memsize_'+str(args.memory_size)
+        print('total memory size ', args.memory_size)
+    else:
+        print('inf_mem')
+        print(args.memory_filling)
+if args.shuffle_stream:
+    strategy = strategy+'_shuffled' 
+# Logging path
+# --log_root has a non-None default, so the hard-coded fallback is normally overridden.
+log_root = Path('/home/boehlke/AMMOD/continual_learning/results/')
+if args.log_root is not None:
+    log_root = Path(args.log_root)
+# Resulting layout: <log_root>/<model settings>/<file>/<strategy settings>
+model_settings = 'resnet'+str(model_depth)+'_smoothloss_adamW_eps'+str(eps)+'wd'+str(weight_decay)+'_bs'+str(args.batch_size)
+strategy_settings = strategy+rehersal+data_ratio+'_ep'+str(args.epochs)+'_ne'+str(args.eval_after_n_exp)
+args.log_dir_name = log_root / model_settings/ args.file / strategy_settings
+print(args.log_dir_name)
+Path.mkdir(args.log_dir_name, parents=True, exist_ok=True)
+
+
+# Sanity checks: warn (without aborting) when the number of distinct labels found in
+# column 1 of the test/validation stream differs from --num_classes.
+cls_in_teststream = np.unique(np.array(args.test_stream)[:, 1])
+if len(cls_in_teststream) != args.num_classes:
+    print('ATTENTION: args.num_classes doesnt mathc number of classes in test_stream')
+    print(args.num_classes)
+
+cls_in_valstream = np.unique(np.array(args.val_stream)[:, 1])
+if len(cls_in_valstream) != args.num_classes:
+    print('ATTENTION: args.num_classes doesnt mathc number of classes in val_stream')
+    print(args.num_classes)
+
+# Load the label dictionary. The rest of the script expects int labels as keys and
+# class names as values.
+with open(args.label_dict, 'rb') as label_file:
+    label_dict = pkl.load(label_file)
+print(label_dict.items())
+
+# If 0 is not a key, the pickle maps class name -> int label, so the mapping is inverted.
+if 0 not in label_dict:
+    print('label_dict wrong way around')
+    flip_lable_dict = {value: key for key, value in label_dict.items()}
+    label_dict = flip_lable_dict
+
+
+
+
+# augmentations for train and val/test data
+# Training images are resized to int(224/0.875) = 256 pixels per side and then randomly cropped to 224x224.
+resize = int(224/0.875)
+args.data_transforms = {
+    # 'train': resize, random crop and random horizontal flip, then tensor conversion + normalization.
+    'train': transforms.Compose([
+        transforms.Resize((resize,resize)),
+        transforms.RandomCrop(size=(224,224)),
+        transforms.RandomHorizontalFlip(), 
+        transforms.ToTensor(),
+        #transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1), 
+        # presumably the usual ImageNet channel means / standard deviations -- TODO confirm
+        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+    ]),
+    # 'val': deterministic resize to 224x224 with the same normalization; used as eval transform.
+    'val': transforms.Compose([
+        transforms.Resize((224,224)),
+        transforms.ToTensor(),
+        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+    ]),
+    }
+
+
+
+
+
+def get_scenario(args):
+
+    '''Builds the benchmark (scenario) and the evaluation plugin for this run.
+
+    Returns a benchmark instance given a sequence of lists of files. A separate
+    dataset will be created for each list. Each of those datasets
+    will be considered a separate experience.
+
+    Function paths_benchmark() is defined in /avalanche/scenarios/generic_benchmark_creation.py
+    under "create_generic_benchmark_from_paths()" which has been augmented from
+    the original avalanche version with a path_dataset_class variable which allows for handling of data
+    with more than path and label information. In this case, the sequence code is passed as well to allow
+    for sequence-wise evaluation.
+
+    Side effects: the EvaluationPlugin is stored in args.combined_logger, and for
+    --strategy joint args.train_stream is replaced by a single merged experience.
+    '''
+    
+    # Joint training: merge all experiences into one experience containing every training entry.
+    if args.strategy == 'joint':
+        joint_train_stream = [item for sublist in args.train_stream for item in sublist ]
+        args.train_stream = [joint_train_stream]
+        
+    # One task label (always 0) per training experience.
+    task_labels = np.zeros((1,len(args.train_stream))).astype(int).tolist()[0]
+
+    # `nac` is only referenced by metrics that are currently commented out.
+    nac =torch.tensor(args.non_animal_cls, device=args.detected_device)
+    nc = args.num_classes
+
+    # All active metrics are computed in eval mode and reset/emitted once per evaluated stream.
+    metrics = [
+        #TrainingTime(reset_at='epoch', emit_at='epoch', mode='train'), 
+        #accuracy_metrics(minibatch=False, epoch=True, experience=False, stream=True),
+        Top1AccTransfer(nr_train_exp=len(args.train_stream), reset_at='stream', emit_at='stream', mode='eval'),
+        ClassTop_nAvgAcc(nr_classes=nc, n=1, reset_at='stream', emit_at='stream', mode='eval'),
+        #ClassTop_nAvgAcc(nr_classes=nc, n=1,reset_at='epoch', emit_at='epoch', mode='train'),
+        ClassTop_nAvgAcc(nr_classes=nc, n=2, reset_at='stream', emit_at='stream', mode='eval'),
+        SeasonClassTop_nAcc(nr_classes=nc, n=1, winter_files_list=args.all_winter_files, season='summer', reset_at='stream', emit_at='stream', mode='eval'),
+        SeasonClassTop_nAcc(nr_classes=nc, n=1, winter_files_list=args.all_winter_files, season='winter', reset_at='stream', emit_at='stream', mode='eval'),
+        #ClassTop_nAvgAcc(nr_classes=nc, n=2,reset_at='epoch', emit_at='epoch', mode='train'),
+        ClasswiseTop_nAcc(nr_classes=nc, n=1, reset_at='stream', emit_at='stream', mode='eval'),
+        Top_nAcc(n=1, reset_at='stream', emit_at='stream', mode='eval'),
+        SeasonTop_nAcc(n=1, winter_files_list=args.all_winter_files, season='summer', reset_at='stream', emit_at='stream', mode='eval'),
+        SeasonTop_nAcc(n=1, winter_files_list=args.all_winter_files, season='winter', reset_at='stream', emit_at='stream', mode='eval'),
+        #Top_nAcc(n=2, reset_at='stream', emit_at='stream', mode='eval'),
+        #Top_nAcc(n=1, reset_at='epoch', emit_at='epoch', mode='train'),
+        #Top_nAcc(n=2, reset_at='epoch', emit_at='epoch', mode='train'),
+        SeqTop_nAcc(n=1, reset_at='stream', emit_at='stream', mode='eval'),
+        SeqMaxAcc( reset_at='stream', emit_at='stream', mode='eval'),
+        SeasonSeqTop_nAcc(num=1, winter_files_list=args.all_winter_files, season='summer', reset_at='stream', emit_at='stream', mode='eval'),
+        SeasonSeqTop_nAcc(num=1, winter_files_list=args.all_winter_files, season='winter', reset_at='stream', emit_at='stream', mode='eval'),
+        #SeqTop_nAcc(n=2, reset_at='stream', emit_at='stream', mode='eval'),
+        #SeqTop_nAcc(n=3, reset_at='stream', emit_at='stream', mode='eval'),
+        #SeqTop_nAcc(n=1, reset_at='epoch', emit_at='epoch', mode='train'),
+        #SeqTop_nAcc(n=2, reset_at='epoch', emit_at='epoch', mode='train'),
+        #SeqTop_nAcc(n=3, reset_at='epoch', emit_at='epoch', mode='train'),  
+        SeqClasswiseTop_nAcc(nr_classes=nc, n=1, reset_at='stream', emit_at='stream', mode='eval'),
+        SeqClassTop_nAvgAcc(nr_classes=nc, n=1, reset_at='stream', emit_at='stream', mode='eval'),
+        SeqClassAvgMaxAcc(nr_classes=nc, reset_at='stream', emit_at='stream', mode='eval'),
+        SeasonSeqClassTop_nAcc(nr_classes=nc, n=1, winter_files_list=args.all_winter_files, season='summer', reset_at='stream', emit_at='stream', mode='eval'),
+        SeasonSeqClassTop_nAcc(nr_classes=nc, n=1, winter_files_list=args.all_winter_files, season='winter', reset_at='stream', emit_at='stream', mode='eval'),
+        #SeqClassTop_nAvgAcc(nr_classes=nc, n=1,reset_at='epoch', emit_at='epoch', mode='train'),
+        #SeqClassTop_nAvgAcc(nr_classes=nc, n=2, reset_at='stream', emit_at='stream', mode='eval'),
+        #SeqClassTop_nAvgAcc(nr_classes=nc, n=2,reset_at='epoch', emit_at='epoch', mode='train'),
+        SeqAnyAcc(reset_at='stream', emit_at='stream', mode='eval'),
+        #SeqAnyAcc(reset_at='epoch', emit_at='epoch', mode='train'),
+        #BinaryAnimalAcc(non_animal_class =nac, reset_at='epoch', emit_at='epoch', mode='train'),
+        #BinaryAnimalAcc(non_animal_class =nac,reset_at='stream', emit_at='stream', mode='eval'),
+        #QuatrupleAnimalAcc(non_animal_class=nac, reset_at='epoch', emit_at='epoch', mode='train'),
+        #QuatrupleAnimalAcc(non_animal_class=nac,reset_at='stream', emit_at='stream', mode='eval'),
+        #SeqBinaryAnimalAcc(non_animal_class=nac,reset_at='epoch', emit_at='epoch', mode='train'),
+        #SeqBinaryAnimalAcc(non_animal_class=nac,reset_at='stream', emit_at='stream', mode='eval'),
+        #SeqIgnoreNonAnimalOutputsAcc(non_animal_class=nac,nr_classes=nc, reset_at='epoch', emit_at='epoch', mode='train'),
+        #SeqIgnoreNonAnimalOutputsAcc(non_animal_class=nac,nr_classes=nc,reset_at='stream', emit_at='stream', mode='eval'),
+        #ClassIgnoreNonAnimalOutputsTop_nAvgAcc(nr_classes=nc,n=1,non_animal_class=nac,reset_at='epoch', emit_at='epoch', mode='train'),
+        #ClassIgnoreNonAnimalOutputsTop_nAvgAcc(nr_classes=nc,n=1, non_animal_class=nac,reset_at='stream', emit_at='stream', mode='eval'),
+        #IgnoreNonAnimalOutputsTop_nAcc(non_animal_class=nac, nr_classes=nc, n=1,  reset_at='epoch', emit_at='epoch', mode='train'),
+        #IgnoreNonAnimalOutputsTop_nAcc(non_animal_class=nac,nr_classes=nc, n=1, reset_at='stream', emit_at='stream', mode='eval'),
+        #NrStepsBatchwise(reset_at='epoch', emit_at='epoch', mode='train'),
+        #NrStepsImagewise(reset_at='epoch', emit_at='epoch', mode='train')
+               ]
+
+
+    # Unpacks any entry that is itself a list (e.g. a helper returning several metrics at once).
+    metrics = flattened(metrics)
+
+    # Timestamp with ' ' -> '_' and ':' / '.' -> '-'; the last four characters are cut off.
+    # It keeps the TensorBoard and CSV output of repeated runs in separate sub-folders.
+    datetime_stamp = str(datetime.datetime.now()).replace(' ','_').replace(':','-').replace('.','-')[:-4]
+    tb_logger = TensorboardLogger(tb_log_dir=str(args.log_dir_name)+"/tb_data_"+datetime_stamp, filename_suffix='test_run')
+    interactive_logger = InteractiveLogger()
+    gen_csv_logger = GenericCSVLogger(log_folder=str(args.log_dir_name)+'/gen_csvlogs_'+datetime_stamp)
+
+    # Stored on args so that get_strategy() can pass it to the strategy as evaluator.
+    args.combined_logger = EvaluationPlugin(
+        metrics,
+        loggers=[tb_logger, gen_csv_logger, interactive_logger],
+        suppress_warnings=False)
+
+
+    # The test and validation streams each consist of a single experience holding the complete stream.
+    # NOTE(review): the extra stream is registered under the key 'validation', while its transform
+    # is registered under 'validation_stream' -- confirm that the forked paths_benchmark matches
+    # these two keys; otherwise the validation stream may not receive the 'val' transform.
+    scenario = paths_benchmark(
+           args.train_stream,
+           [args.test_stream],  
+           other_streams_lists_of_files={'validation': [args.val_stream]},
+           task_labels=task_labels,
+           complete_test_set_only=True,
+           train_transform=args.data_transforms['train'],
+           eval_transform=args.data_transforms['val'],
+           other_streams_transforms={'validation_stream': args.data_transforms['val']},
+           path_dataset_class=SeqPathsDataset,
+           common_root = args.images_root
+        )
+
+    return scenario
+
+def _resolve_memory_size(args):
+    ''' Returns the rehearsal memory capacity (number of instances) requested via --memory_size.
+
+    Values below 1 are interpreted as a fraction of the training stream, whose length is
+    estimated as (size of the first experience) * (number of experiences); any other value
+    is truncated to an integer instance count.
+    '''
+    total_train_data = len(args.train_stream[0])*len(args.train_stream)
+    if args.memory_size < 1:
+        return int(args.memory_size*total_train_data)
+    return int(args.memory_size)
+
+
+def get_strategy(args):
+
+    ''' Returns the avalanche training strategy with the necessary plugins.
+
+    A SeqDataPlugin is always used. Depending on --memory_filling and --rehearsal_method
+    at most one rehearsal plugin is added:
+      * 'inf'   : unbounded memory, plugin chosen by --rehearsal_method
+                  (ce, cf, mir, rr, wce, cefa, fe, feo)
+      * 'cbrs'  : ClassBalancingReservoirMemoryRehersalPlugin
+      * 'stdrs' : ReservoirMemoryRehearsalPlugin
+    The rehearsal plugin is only handed to the default (naive) strategy; the 'joint' and
+    'cumulative' strategies are created with the SeqDataPlugin alone.
+
+    The memory capacity of the bounded-memory plugins is derived from args here, so that
+    the function does not depend on a module-level variable that only exists for some
+    argument combinations. Reads the module-level label_dict and args.combined_logger
+    (set by get_scenario).
+    '''
+    plugin = SeqDataPlugin()
+    plugins = [plugin]
+    plugin_rehearsal = None
+    ratio = args.buffer_data_ratio
+
+    if args.memory_filling == 'inf':
+        if args.rehearsal_method == 'ce':
+            # the temperature variant is only used when a temperature was given
+            if args.temperature is None:
+                plugin_rehearsal = ClassErrorRehersalPlugin(buffer_data_ratio=ratio)
+            else:
+                plugin_rehearsal = ClassErrorRehersalTemperaturePlugin(buffer_data_ratio=ratio, temperature=args.temperature)
+        elif args.rehearsal_method == 'cf':
+            plugin_rehearsal = ClassFrequencyRehearsalPlugin(buffer_data_ratio=ratio)
+        elif args.rehearsal_method == 'mir':
+            plugin_rehearsal = MaximallyInterferedRetrievalRehersalPlugin(buffer_data_ratio=ratio)
+        elif args.rehearsal_method == 'rr':
+            plugin_rehearsal = RandomRehersal(buffer_data_ratio=ratio)
+        elif args.rehearsal_method == 'wce':
+            plugin_rehearsal = WeightedeMovingClassErrorAverageRehersalPlugin(buffer_data_ratio=ratio, nr_of_steps_to_avg=args.nr_of_steps_to_avg, nr_classes=args.num_classes, sigma=args.sigma)
+        elif args.rehearsal_method == 'cefa':
+            plugin_rehearsal = ClassErrorFrequencyAvgRehearsalPlugin(buffer_data_ratio=ratio, nr_classes=args.num_classes)
+        elif args.rehearsal_method == 'fe':
+            plugin_rehearsal = FillExpBasedRehearsalPlugin(buffer_data_ratio=ratio, nr_classes=args.num_classes)
+        elif args.rehearsal_method == 'feo':
+            plugin_rehearsal = FillExpOversampleBasedRehearsalPlugin(buffer_data_ratio=ratio, nr_classes=args.num_classes)
+    elif args.memory_filling == 'cbrs':
+        plugin_rehearsal = ClassBalancingReservoirMemoryRehersalPlugin(buffer_data_ratio=ratio, memory_size=_resolve_memory_size(args), nr_classes=args.num_classes, rehearsal_selection_strategy=args.rehearsal_method, temperature=args.temperature)
+    elif args.memory_filling == 'stdrs':
+        plugin_rehearsal = ReservoirMemoryRehearsalPlugin(buffer_data_ratio=ratio, memory_size=_resolve_memory_size(args), nr_classes=args.num_classes, rehearsal_selection_strategy=args.rehearsal_method, temperature=args.temperature)
+
+    # Report a missing rehearsal plugin only when none was created (this also covers
+    # unknown --memory_filling / --rehearsal_method values, which are otherwise silent).
+    if plugin_rehearsal is None:
+        print('no rehearsal method selected')
+    else:
+        plugins.append(plugin_rehearsal)
+
+    # Arguments shared by all three strategies.
+    common = dict(
+        train_mb_size=args.batch_size,
+        eval_mb_size=args.exp_size,
+        device=args.detected_device,
+        evaluator=args.combined_logger,
+        label_dict=label_dict,
+    )
+
+    if args.strategy == 'joint':
+        # In the joint baseline scenario the naive base strategy defined in
+        # /avalanche/training/strategies/base_strategies.py is used with one epoch per
+        # train() call; the number of epochs is controlled by the caller's loop.
+        print('getting joint strategy')
+        strategy = Naive(
+            args.model,
+            args.optimizer,
+            args.criterion,
+            train_epochs=1,
+            plugins=[plugin],
+            **common
+        )
+    elif args.strategy == 'cumulative':
+        print('getting cumulative strategy')
+        strategy = Cumulative(
+            args.model,
+            args.optimizer,
+            args.criterion,
+            train_epochs=args.epochs,
+            plugins=[plugin],
+            **common
+        )
+    else:
+        print('getting naive strategy')
+        strategy = Naive(
+            args.model,
+            args.optimizer,
+            args.criterion,
+            train_epochs=args.epochs,
+            plugins=plugins,
+            **common
+        )
+    return strategy
+
+
+
+
+# Build the benchmark first: get_scenario() also creates args.combined_logger, which get_strategy() needs.
+scenario = get_scenario(args)
+cl_strategy = get_strategy(args)
+
+
+
+# different loops are necessary for different strategies. 
+# when using a class error based method, the evaluation on the validation data should occur more frequently, meaning the variable eval_after_n_exp (-ne) should be less than 10
+if args.strategy == 'joint':
+
+    # Joint training: the single merged experience is trained once per loop iteration, so
+    # --epochs is the number of passes and the evaluation intervals count epochs.
+    for i in range(args.epochs):
+        #print('joint ', dir(cl_strategy) )
+        cl_strategy.train(scenario.train_stream[0], num_workers=number_workers)
+        print('Training completed')
+        last_eval = i ==args.epochs-1
+
+        if i%args.test_eval_after_n_exp==0 or last_eval:
+            print('Computing accuracy on the whole test set')
+            cl_strategy.eval(scenario.test_stream, num_workers=number_workers, last_eval=True)
+        if i%args.eval_after_n_exp==0 or last_eval:
+            print('Computing accuracy on the whole validation set')
+            cl_strategy.eval(scenario.validation_stream, num_workers=number_workers, last_eval=True)
+        # After the final epoch the training stream itself is evaluated as well.
+        # NOTE(review): last_eval=False here, whereas the non-joint loop below passes last_eval=True -- confirm intent.
+        if last_eval:
+            cl_strategy.eval(scenario.train_stream, num_workers=number_workers, last_eval=False)
+
+
+else:
+    # Continual training: one train() call per experience; validation and test evaluation run
+    # every eval_after_n_exp / test_eval_after_n_exp experiences and always after the last one.
+    train_stream_len = len(scenario.train_stream)
+    for i, experience in enumerate(scenario.train_stream):
+        cl_strategy.train(experience, num_workers=number_workers)
+        print('Training completed')
+        last_eval = i ==train_stream_len-1
+        if (i%args.eval_after_n_exp==0 or last_eval):
+            print('Evaluaton after experience: ', i )
+            #cl_strategy.eval(scenario.train_stream, num_workers=number_workers, last_eval=True)
+            print('Computing accuracy on the whole validation set')
+            cl_strategy.eval(scenario.validation_stream, num_workers=number_workers, last_eval=True)
+            
+
+        if (i%args.test_eval_after_n_exp==0 or last_eval):
+            print('Evaluaton after experience: ', i )
+            print('Computing accuracy on the whole test set')
+            cl_strategy.eval(scenario.test_stream, num_workers=number_workers, last_eval=True)
+
+        # After the last experience the complete training stream is evaluated as well.
+        if last_eval:
+
+            cl_strategy.eval(scenario.train_stream, num_workers=number_workers, last_eval=True)

+ 96 - 0
scripts/example_script.sh

@@ -0,0 +1,96 @@
+#!/bin/bash
+# Shared settings for the cont_ava.py example runs in this script.
+export CUDA_VISIBLE_DEVICES=0 # make only GPU 0 visible to CUDA
+
+# Input data, label dictionary and the directory passed via -log
+data_root="../data/data_stream_files/BW_stream_files/"
+images_root="/home/AMMOD_data/camera_traps/BayerWald/G-Fallen/MDcrops/"
+label_dict="../data/label_dictionaries/BIRDS_11_Species.pkl"
+LOG="../example_LOG/"
+# Training script and run parameters
+SCRIPT="cont_ava.py"
+BATCH_SIZE=128
+num_classes=11
+echo "$SCRIPT"
+#nested loop to iterate over all cross validations splits
+# 
+# for eps in 5 # epochs
+# do
+# for exp_size in 128 
+# do
+# for i in  0 1 2 3 4 # cross validation splits
+# do
+# FILE="cv${i}_expsize${exp_size}_crop"
+# for dr in 8  #data ratio
+# do
+# 
+# cf => class inverse frequency rehearsal
+# python $SCRIPT --file $FILE --strategy naive --data_root $data_root --images_root $images_root --num_classes $num_classes -log $LOG -rm cf -dr $dr --batch_size $BATCH_SIZE --epochs $eps -ne 100  --label_dict $label_dict 
+# (shuffled)
+# python $SCRIPT --file $FILE --shuffle_stream --strategy naive --data_root $data_root --images_root $images_root --num_classes $num_classes -log $LOG -rm cf -dr $dr --batch_size $BATCH_SIZE --epochs $eps -ne 100  --label_dict $label_dict
+# 
+# rr => random rehearsal (uniform sampling)
+# python $SCRIPT --file $FILE --strategy naive --data_root $data_root --images_root $images_root --num_classes $num_classes -log $LOG -rm rr -dr $dr --batch_size $BATCH_SIZE --epochs $eps -ne 100  --label_dict $label_dict 
+# (shuffled)
+# python $SCRIPT --file $FILE --shuffle_stream --strategy naive --data_root $data_root --images_root $images_root --num_classes $num_classes -log $LOG -rm rr -dr $dr --batch_size $BATCH_SIZE --epochs $eps -ne 100  --label_dict $label_dict
+# 
+# mir => maximally interfered retrieval based rehearsal
+# python $SCRIPT --file $FILE --strategy naive --data_root $data_root --images_root $images_root --num_classes $num_classes -log $LOG -rm mir -dr $dr --batch_size $BATCH_SIZE --epochs $eps -ne 100  --label_dict $label_dict 
+# (shuffled)
+# python $SCRIPT --file $FILE --shuffle_stream --strategy naive --data_root $data_root --images_root $images_root --num_classes $num_classes -log $LOG -rm mir -dr $dr --batch_size $BATCH_SIZE --epochs $eps -ne 100  --label_dict $label_dict
+# 
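+# ce => class error based rehearsal (with eval_after_n_exp, i.e. -ne, equal to 5); NOTE: $temp is not defined in this script and must be set before uncommenting the commands below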
+# python $SCRIPT --file $FILE --strategy naive --data_root $data_root --images_root $images_root --num_classes $num_classes -log $LOG -rm ce -dr $dr --batch_size $BATCH_SIZE --epochs $eps -ne 5 -tne 100 -temp $temp  --label_dict $label_dict 
+# (shuffled)
+# python $SCRIPT --file $FILE --shuffle_stream --strategy naive --data_root $data_root --images_root $images_root --num_classes $num_classes -log $LOG -rm ce -dr $dr --batch_size $BATCH_SIZE --epochs $eps -ne 5 -tne 100 -temp $temp  --label_dict $label_dict
+# 
+# wce => weighted moving average of the class error based method: the 18 most recent evaluations on the validation data influence the weights, which are drawn from a Gaussian bell curve with sigma 9
+# python $SCRIPT --file $FILE --strategy naive --data_root $data_root --images_root $images_root --num_classes $num_classes -log $LOG -rm wce -nexp_avg 18 -sig 9  -dr $dr --batch_size $BATCH_SIZE --epochs $eps -ne 5 -tne 100  --label_dict $label_dict 
+# (shuffled)
+# python $SCRIPT --file $FILE --shuffle_stream --strategy naive --data_root $data_root --images_root $images_root --num_classes $num_classes -log $LOG -rm wce -nexp_avg 18 -sig 9  -dr $dr --batch_size $BATCH_SIZE --epochs $eps -ne 5 -tne 100 --label_dict $label_dict
+# 
+# done
+# done
+# done
+# done
+# 
+
+
+
+
+
+
+# Nested loops running the experiments with limited memory and different memory filling strategies.
+# Loop order (outer to inner): epochs, experience size, cross validation split, data ratio, memory size, memory filling strategy.
+for eps in 5 # epochs
+do
+for exp_size in 128 # experience size, part of the stream file name
+do
+for i in 0 1 2 3 4 # cross validation split
+do
+FILE="cv${i}_expsize${exp_size}_crop"
+for dr in 2 8 # data ratio
+do
+for mem in 0.1 0.25 # memory size: one tenth and one quarter of the total stream data
+do
+for memf_strategy in cbrs stdrs # memory filling strategy passed via -memf
+do
+# Per setting: ce, cf and rr rehearsal on the original stream, then the same three with --shuffle_stream. The ce runs use -ne 5 and, like the ce examples above, pass -tne 100 explicitly.
+python $SCRIPT --file $FILE --strategy naive --data_root $data_root --images_root $images_root --num_classes $num_classes -log $LOG   -rm ce -memf $memf_strategy -mems $mem -dr $dr --batch_size $BATCH_SIZE --epochs $eps -ne 5  -tne 100 --label_dict $label_dict
+python $SCRIPT --file $FILE --strategy naive --data_root $data_root --images_root $images_root --num_classes $num_classes -log $LOG   -rm cf -memf $memf_strategy -mems $mem -dr $dr --batch_size $BATCH_SIZE --epochs $eps -ne 100 --label_dict $label_dict
+python $SCRIPT --file $FILE --strategy naive --data_root $data_root --images_root $images_root --num_classes $num_classes -log $LOG   -rm rr -memf $memf_strategy -mems $mem -dr $dr --batch_size $BATCH_SIZE --epochs $eps -ne 100 --label_dict $label_dict
+python $SCRIPT --file $FILE --shuffle_stream --strategy naive --data_root $data_root --images_root $images_root --num_classes $num_classes -log $LOG   -rm ce -memf $memf_strategy -mems $mem -dr $dr --batch_size $BATCH_SIZE --epochs $eps -ne 5  -tne 100 --label_dict $label_dict
+python $SCRIPT --file $FILE --shuffle_stream --strategy naive --data_root $data_root --images_root $images_root --num_classes $num_classes -log $LOG   -rm cf -memf $memf_strategy -mems $mem -dr $dr --batch_size $BATCH_SIZE --epochs $eps -ne 100 --label_dict $label_dict
+python $SCRIPT --file $FILE --shuffle_stream --strategy naive --data_root $data_root --images_root $images_root --num_classes $num_classes -log $LOG   -rm rr -memf $memf_strategy -mems $mem -dr $dr --batch_size $BATCH_SIZE --epochs $eps -ne 100 --label_dict $label_dict
+done # memf_strategy
+
+done # mem
+done # dr
+done # i
+done # exp_size
+done # eps
+
+
+
+
+

+ 6 - 5
scripts/jupyter_notbooks/introduction_to_avalanche.ipynb

@@ -602,7 +602,6 @@
    "outputs": [],
    "source": [
     "val_stream_file = data_dir_path+'data_stream_files/BW_stream_files/cv0_expsize128_crop_val_stream.pkl'\n",
-    "label_dict_file = data_dir_path+'label_dictionaries/BIRDS_11_Species.pkl'\n",
     "with open(val_stream_file, 'rb') as handle: \n",
     "    val_stream = pkl.load(handle)\n",
     "# The val_stream is passed to the pahts_benchmark function below, that creats the 'scenario'"
@@ -730,11 +729,13 @@
    ]
   },
   {
-   "cell_type": "code",
-   "execution_count": null,
+   "cell_type": "markdown",
    "metadata": {},
-   "outputs": [],
-   "source": []
+   "source": [
+    "## Bringing it all together for cross-validated Experiments\n",
+    "\n",
+    "In order to perform experiments on all cross-validation datasets, a script was used that reads arguments from the command line, with which different parameters and options can be set for all the continual learning strategies implemented. The script is found in scripts/cont_ava.py. The scripts/example_script.sh file shows how this script was used to run experiments with different settings on multiple cross-validation splits. The names of the arguments and a brief statement on what they are can be seen with python cont_ava.py --help or in the file itself. "
+   ]
   }
  ],
  "metadata": {