Pārlūkot izejas kodu

rehearsal learning and limited memory filling together with updates to introductory notebook

Julia Boehlke 2 gadi atpakaļ
vecāks
revīzija
3682c6c83c

+ 1 - 1
avalanche/avalanche/benchmarks/scenarios/generic_benchmark_creation.py

@@ -134,7 +134,7 @@ def create_multi_dataset_generic_benchmark(
                 dataset,
                 transform_groups=transform_groups,
                 initial_transform_group=initial_transform_group,
-                dataset_type=dataset_type))
+                dataset_type=dataset_type, paths = dataset.paths))
         stream_definitions[stream_name] = (stream_datasets,)
 
     return GenericCLScenario(

+ 30 - 2
avalanche/avalanche/models/utils.py

@@ -1,6 +1,6 @@
 from avalanche.models.dynamic_modules import MultiTaskModule
 import torch.nn as nn
-
+from torch.nn import functional as F
 
 def avalanche_forward(model, x, task_labels):
     if isinstance(model, MultiTaskModule):
@@ -51,7 +51,35 @@ class FeatureExtractorBackbone(nn.Module):
             self.get_activation())
 
 
+def linear_combination(x, y, epsilon):
+    return epsilon * x + (1 - epsilon) * y
+
+def reduce_loss(loss, reduction="mean"):
+    if reduction == "mean":
+        return loss.mean()
+    
+    elif reduction == "sum":
+        return loss.sum()
+    
+    else:
+        return loss
+    
+class LabelSmoothingCrossEntropy(nn.Module):
+    def __init__(self, epsilon: float = 0.1, reduction="mean"):
+        super(LabelSmoothingCrossEntropy, self).__init__()
+        self.epsilon = epsilon
+        self.reduction = reduction
+    def forward(self, preds, target):
+        n_classes = preds.size()[-1]
+        
+        log_preds = F.log_softmax(preds, dim=-1)
+        loss = reduce_loss(-log_preds.sum(dim=-1), self.reduction)
+        nll = F.nll_loss(log_preds, target, reduction=self.reduction)
+        
+        return linear_combination(loss / n_classes, nll, self.epsilon)
+
 __all__ = [
     'avalanche_forward',
-    'FeatureExtractorBackbone'
+    'FeatureExtractorBackbone',
+    'LabelSmoothingCrossEntropy'
 ]

+ 11 - 0
avalanche/avalanche/training/plugins/__init__.py

@@ -7,7 +7,18 @@ from .gem import GEMPlugin
 from .lwf import LwFPlugin
 from .replay import ReplayPlugin, StoragePolicy, ClassBalancedStoragePolicy, \
     ExperienceBalancedStoragePolicy
+from .imbalance_focus_replay import (MaximallyInterferedRetrievalRehersalPlugin, 
+                                    RandomRehersal, 
+                                    ClassErrorRehersalPlugin, 
+                                    ClassErrorRehersalTemperaturePlugin, 
+                                    ClassFrequencyRehearsalPlugin, 
+                                    ClassImbalanceRehersalPlugin,
+                                    WeightedeMovingClassErrorAverageRehersalPlugin,
+                                    ClassErrorFrequencyAvgRehearsalPlugin,
+                                    FillExpBasedRehearsalPlugin,
+                                    FillExpOversampleBasedRehearsalPlugin)
 from .strategy_plugin import StrategyPlugin
+from .class_balancing_memory import ClassImbalanceMemoryRehersalPlugin, ClassBalancingReservoirMemoryRehersalPlugin, ReservoirMemoryRehearsalPlugin
 from .synaptic_intelligence import SynapticIntelligencePlugin
 from .cope import CoPEPlugin, PPPloss
 from .sequence_data import SeqDataPlugin

+ 398 - 0
avalanche/avalanche/training/plugins/class_balancing_memory.py

@@ -0,0 +1,398 @@
+import random
+from abc import ABC, abstractmethod
+from typing import Dict, List, Optional, TYPE_CHECKING
+
+import torch
+from numpy import inf
+import numpy as np
+from torch import Tensor, cat
+from torch.nn import Module, CrossEntropyLoss
+from torch.utils.data import random_split, DataLoader
+from torch.optim import Optimizer, SGD, AdamW
+
+from avalanche.benchmarks.utils import AvalancheConcatDataset, \
+	AvalancheDataset, AvalancheSubset
+from avalanche.benchmarks.utils.data_loader import \
+	ReplayDataLoader
+from avalanche.models import FeatureExtractorBackbone
+from avalanche.training.plugins.strategy_plugin import StrategyPlugin
+
+if TYPE_CHECKING:
+	from avalanche.training.strategies import BaseStrategy
+
+from avalanche.models.utils import avalanche_forward
+from avalanche.benchmarks.utils.data_loader import TaskBalancedDataLoader, _seq_collate_mbatches_fn
+from avalanche.models.dynamic_optimizers import reset_optimizer
+import copy
+import numpy as np
+from avalanche.models.utils import LabelSmoothingCrossEntropy
+
+
+
+def _get_class_error_based_indicies(strategy, rehearsal_plugin):
+
+	nr_indices_to_return = int(len(strategy.experience.dataset)*rehearsal_plugin.buffer_data_ratio)
+	if len(rehearsal_plugin.memory_dataset)<= nr_indices_to_return:
+		return None
+
+	assert strategy.val_cls_acc_dict is not None
+
+	idx_rehersal_prob = torch.zeros((len(rehearsal_plugin.memory_dataset)))
+	for cls_i, cls_acc_i in strategy.val_cls_acc_dict.items():
+		if isinstance(cls_acc_i.result(), dict) and 0 in cls_acc_i.result().keys() : 
+			cls_acc = cls_acc_i.result()[0]
+			cls_rehersal_prob= 1-cls_acc
+		else: 
+			cls_acc = 0
+
+		cls_rehersal_prob= 1-cls_acc
+		total_seen_for_cls = len(idx_rehersal_prob[rehearsal_plugin.cls_idx_dict[cls_i]])
+
+		if total_seen_for_cls!=0:
+			idx_rehersal_prob[rehearsal_plugin.cls_idx_dict[cls_i]] = cls_rehersal_prob/total_seen_for_cls
+		else:
+			idx_rehersal_prob[rehearsal_plugin.cls_idx_dict[cls_i]] = 0
+
+	idxs = idx_rehersal_prob.multinomial(num_samples=nr_indices_to_return, replacement=False)
+	return idxs
+
+
+def _get_class_error_temp_based_indicies(strategy, rehearsal_plugin):
+
+	nr_indices_to_return = int(len(strategy.experience.dataset)*rehearsal_plugin.buffer_data_ratio)
+	if len(rehearsal_plugin.memory_dataset)<= nr_indices_to_return:
+		return None
+
+	assert strategy.val_cls_acc_dict is not None
+
+	idx_rehersal_prob = torch.zeros((len(rehearsal_plugin.memory_dataset)))
+
+	for cls_i, cls_acc_i in strategy.val_cls_acc_dict.items():
+		if isinstance(cls_acc_i.result(), dict) and 0 in cls_acc_i.result().keys() : 
+			cls_acc = cls_acc_i.result()[0]
+			cls_rehersal_prob= 1-cls_acc
+		else:  
+			cls_acc = 0
+		cls_rehersal_prob= torch.exp(torch.tensor((1-cls_acc)/rehearsal_plugin.temperature))
+		total_seen_for_cls = len(rehearsal_plugin.cls_idx_dict[cls_i])
+		if total_seen_for_cls!=0:
+			idx_rehersal_prob[rehearsal_plugin.cls_idx_dict[cls_i]] = cls_rehersal_prob/total_seen_for_cls
+		else:
+			idx_rehersal_prob[rehearsal_plugin.cls_idx_dict[cls_i]] = 0
+
+	idxs = idx_rehersal_prob.multinomial(num_samples=nr_indices_to_return, replacement=False)
+	return idxs
+
+
+def _get_class_random_indicies(strategy, rehearsal_plugin):
+	nr_indices_to_return = int(len(strategy.experience.dataset)*rehearsal_plugin.buffer_data_ratio)
+	if len(rehearsal_plugin.memory_dataset)<= nr_indices_to_return:
+		return None
+	else:
+		idxs = random.sample(range(len(rehearsal_plugin.memory_dataset)), nr_indices_to_return)
+		idxs = torch.tensor(idxs)
+		return idxs
+
+def _get_frequency_based_indicies(strategy, rehearsal_plugin):
+
+	nr_indices_to_return = int(len(strategy.experience.dataset)*rehearsal_plugin.buffer_data_ratio)
+	if len(rehearsal_plugin.memory_dataset)<= nr_indices_to_return:
+		return None
+
+	idx_rehersal_prob = torch.zeros((len(rehearsal_plugin.memory_dataset)))
+	total_data_len = len(rehearsal_plugin.memory_dataset)
+	summed_inv_cls_freq = 0
+	for cls_i, c_idxs  in rehearsal_plugin.cls_idx_dict.items():
+		if c_idxs.shape[0] != 0:
+			inv_cls_freq = 1/c_idxs.shape[0]
+		else:
+                    inv_cls_freq = 0
+		summed_inv_cls_freq += inv_cls_freq
+		total_seen_for_cls = len(idx_rehersal_prob[rehearsal_plugin.cls_idx_dict[cls_i]])
+		if total_seen_for_cls != 0 :
+			idx_rehersal_prob[rehearsal_plugin.cls_idx_dict[cls_i]] = inv_cls_freq/total_seen_for_cls
+		else:
+			idx_rehersal_prob[rehearsal_plugin.cls_idx_dict[cls_i]] = 0
+
+	idx_rehersal_prob = idx_rehersal_prob/summed_inv_cls_freq
+	idxs = idx_rehersal_prob.multinomial(num_samples=nr_indices_to_return, replacement=False)
+	return idxs
+
+
+def _get_frequency_temp_based_indicies(strategy, rehearsal_plugin):
+
+	nr_indices_to_return = int(len(strategy.experience.dataset)*rehearsal_plugin.buffer_data_ratio)
+	if len(rehearsal_plugin.memory_dataset)<= nr_indices_to_return:
+		return None
+
+	idx_rehersal_prob = torch.zeros((len(rehearsal_plugin.memory_dataset)))
+	total_data_len = len(rehearsal_plugin.memory_dataset)
+	summed_inv_cls_freq = 0
+	for cls_i, c_idxs  in rehearsal_plugin.cls_idx_dict.items():
+		if c_idxs.shape[0] != 0:
+			inv_cls_freq = 1/c_idxs.shape[0]
+			inv_cls_freq = torch.exp(torch.tensor(inv_cls_freq/rehearsal_plugin.temperature))
+		else:
+			inv_cls_freq = 0
+		summed_inv_cls_freq += inv_cls_freq
+		total_seen_for_cls = len(idx_rehersal_prob[rehearsal_plugin.cls_idx_dict[cls_i]])
+		if total_seen_for_cls != 0 :
+			idx_rehersal_prob[rehearsal_plugin.cls_idx_dict[cls_i]] = inv_cls_freq/total_seen_for_cls
+		else:
+			idx_rehersal_prob[rehearsal_plugin.cls_idx_dict[cls_i]] = 0
+
+	idx_rehersal_prob = idx_rehersal_prob/summed_inv_cls_freq
+	idxs = idx_rehersal_prob.multinomial(num_samples=nr_indices_to_return, replacement=False)
+	return idxs
+
+
+
+global rehearsal_selection_strategy_dict
+rehearsal_selection_strategy_dict = {'ce':_get_class_error_based_indicies,
+									 'cet': _get_class_error_temp_based_indicies,
+									 'cft': _get_frequency_temp_based_indicies,
+									 'cf': _get_frequency_based_indicies,
+									 'rr': _get_class_random_indicies }
+
+class ClassImbalanceMemoryRehersalPlugin(StrategyPlugin):
+	# Parent class of the plugins that handle a limited memory and the filling of that memory.
+	# The memory_dataset attribute stores the data that is saved.
+	# memory_size is an integer defining how many instances can be stored in the memory.
+	# rehearsal_selection_strategy has to be a string code defining how the data from the memory
+	# should be selected for the rehearsal set. The module-level rehearsal_selection_strategy_dict
+	# maps that code to the function returning the memory indices picked for rehearsal.
+	# Subclasses are expected to provide _get_indices(); it is called by
+	# train_dataset_adaptation() but not defined in this class.
+
+	def __init__(self, buffer_data_ratio, memory_size, rehearsal_selection_strategy, temperature):
+		"""Store the configuration and resolve the selection function.
+
+		When ``temperature`` is not None the codes 'ce' and 'cf' are switched to
+		their temperature variants 'cet' and 'cft'. An unknown code raises
+		``KeyError`` from the dictionary lookup.
+		"""
+		super().__init__()
+
+		self.memory_dataset = None
+		self.cls_idx_dict = {}
+		self.buffer_data_ratio = buffer_data_ratio
+		self.memory_size = memory_size
+		if temperature is not None and rehearsal_selection_strategy=='ce':
+			rehearsal_selection_strategy = 'cet'
+		if temperature is not None and rehearsal_selection_strategy=='cf':
+			rehearsal_selection_strategy = 'cft'
+		self.temperature = temperature
+		self.rehearsal_selection_strategy = rehearsal_selection_strategy_dict[rehearsal_selection_strategy]
+
+
+	def train_dataset_adaptation(self, strategy: 'BaseStrategy', **kwargs):
+		"""Set ``strategy.adapted_dataset`` to the current experience plus a rehearsal set.
+
+		Also records the picked indices (or None) in
+		``strategy.rehearsal_indicies_picked`` under the current experience number.
+		"""
+
+		indices = None
+		if self.memory_dataset is not None:
+
+			# NOTE(review): **kwargs is forwarded here; a _get_indices(self, strategy)
+			# override without **kwargs would raise TypeError for non-empty kwargs.
+			indices = self._get_indices(strategy, **kwargs)
+			if indices is None:
+				self.exp_rehersal_set = self.memory_dataset # selection returned None: rehearse the whole memory
+			else:
+				self.exp_rehersal_set = AvalancheSubset(self.memory_dataset, indices=indices)
+
+			# NOTE(review): paths is a list of two lists here, whereas the first-experience
+			# branch below passes a single list - confirm which form
+			# AvalancheConcatDataset expects.
+			strategy.adapted_dataset = AvalancheConcatDataset([strategy.experience.dataset, self.exp_rehersal_set], paths=[strategy.experience.dataset.paths, self.exp_rehersal_set.paths])
+
+		else:
+			strategy.adapted_dataset = AvalancheConcatDataset([strategy.experience.dataset], paths=strategy.experience.dataset.paths) # in first experience
+
+
+		strategy.rehearsal_indicies_picked[strategy.experience.current_experience] = indices
+
+
+class ReservoirMemoryRehearsalPlugin(ClassImbalanceMemoryRehersalPlugin):
+
+	# This plugin is an adaptation of the reservoir sampling algorithm from a stream of data for batches. The original algorithm 
+	# is designed to sample uniformly from a stream, meaning at the end of the probability of each stream element to be in the memory 
+	# is exactly 1/nr_elements_seen_so_far.  
+	# This is a batch-wise adaptation of this algorithm to avoid looping over every element in the experience. 
+	# This memory filling strategy can be seen as a baseline and compared to the class balancing reservoir sampling method described below.
+
+	def __init__(self, buffer_data_ratio, memory_size, nr_classes, rehearsal_selection_strategy, temperature):
+		super().__init__(buffer_data_ratio, memory_size, rehearsal_selection_strategy, temperature )
+		
+		self.nr_classes = nr_classes
+		self.total_stream_length = 0
+		
+
+
+	def after_training_exp(self, strategy: 'BaseStrategy', **kwargs):
+
+		if self.memory_dataset is None:
+			self.memory_dataset = strategy.experience.dataset
+			self.total_stream_length = len(strategy.experience.dataset)
+
+		elif len(self.memory_dataset)+len(strategy.experience.dataset) <= self.memory_size:
+			combined_paths = self.memory_dataset.paths+strategy.experience.dataset.paths
+
+			self.memory_dataset = AvalancheConcatDataset(
+				[self.memory_dataset, strategy.experience.dataset], paths=combined_paths)
+			self.total_stream_length += len(strategy.experience.dataset)
+
+		else:
+			if len(self.memory_dataset)==self.memory_size:
+				experience_data = strategy.experience.dataset
+
+			elif len(self.memory_dataset)<self.memory_size:
+
+				experience_fill_memory = AvalancheSubset(strategy.experience.dataset, indices=range(self.memory_size-len(self.memory_dataset)), paths=strategy.experience.dataset.paths)
+				experience_data = AvalancheSubset(strategy.experience.dataset, indices=range(self.memory_size-len(self.memory_dataset), len(strategy.experience.dataset)), paths=strategy.experience.dataset.paths)
+				memory_dataset_paths = self.memory_dataset.paths+experience_fill_memory.paths
+				self.memory_dataset = AvalancheConcatDataset([self.memory_dataset, experience_fill_memory ], paths=memory_dataset_paths)
+
+				self.total_stream_length += len(experience_fill_memory)#
+
+
+			stream_positions = range(self.total_stream_length, self.total_stream_length+len(experience_data))
+
+			self.total_stream_length = self.total_stream_length+len(experience_data)
+			random_M = [random.randint(0, i) for i in stream_positions]
+			random_M = torch.tensor(random_M)
+
+			exp_idx_less_than_mem_size = (random_M<self.memory_size).nonzero().flatten()
+
+			mem_idx_to_replace = np.random.choice(range(self.memory_size), size=exp_idx_less_than_mem_size.shape[0], replace=True)
+
+			unique_mem_to_replace, index_exp_to_replace = np.unique(mem_idx_to_replace[::-1], return_index=True)#
+
+			index_exp_to_replace = -index_exp_to_replace +(len(exp_idx_less_than_mem_size)-1) 
+
+
+			mem_idx_to_retain = np.array(range(self.memory_size))[torch.logical_not(torch.isin(torch.arange(self.memory_size), torch.tensor(unique_mem_to_replace)))]
+			exp_subset = AvalancheSubset(experience_data, indices=index_exp_to_replace, paths=experience_data.paths)
+			mem_subset = AvalancheSubset(self.memory_dataset, indices=mem_idx_to_retain, paths=self.memory_dataset.paths)
+
+			new_memory_paths = exp_subset.paths+ mem_subset.paths
+			self.memory_dataset = AvalancheConcatDataset([mem_subset, exp_subset],  paths=new_memory_paths)
+
+		all_memory_targets = torch.tensor(self.memory_dataset.targets)
+		strategy.memory_dataset_paths.append(self.memory_dataset.paths)
+		strategy.memory_dataset_targets.append(all_memory_targets)
+		for i in range(strategy.model.num_classes):
+			self.cls_idx_dict[i] = (all_memory_targets==i).nonzero().flatten()
+
+
+	def _get_indices(self, strategy):
+		return self.rehearsal_selection_strategy(strategy, self)
+
+
+class ClassBalancingReservoirMemoryRehersalPlugin(ClassImbalanceMemoryRehersalPlugin):
+
+	# This memory filling strategy was based on the Class-Balancing Reservoir Sampling Algorithm defined in the paper 
+	# 'Online Continual Learning from Imbalanced Data' by Chrysakis et al. The basic idea is to limit the replacement of old instances 
+	# in the memory to classes that are 'full'. Please see the paper for more tetails. This again is a batchwise adaptation of the 
+	# original algorithm.
+
+	def __init__(self, buffer_data_ratio, memory_size, nr_classes, rehearsal_selection_strategy, temperature):
+		super().__init__(buffer_data_ratio, memory_size, rehearsal_selection_strategy, temperature )
+
+		self.full_classes = torch.tensor([])
+		self.instances_in_stream_count_dict = {}
+		self.instances_in_memory_count_dict = {}
+		self.mem_to_stream_ratios_dict = {}
+		self.nr_classes = nr_classes
+
+		for i in range(self.nr_classes):
+			self.instances_in_memory_count_dict[i] = 0
+			self.instances_in_stream_count_dict[i] = 0 
+			self.mem_to_stream_ratios_dict[i] = np.Inf 
+
+
+
+	def after_training_exp(self, strategy: 'BaseStrategy', **kwargs):
+
+		for i in range(strategy.model.num_classes):
+
+			nr_cls_instances = (torch.tensor(strategy.experience.dataset.targets)==i).nonzero().flatten().shape[0]
+
+			if nr_cls_instances !=0:
+
+				self.instances_in_stream_count_dict[i] += nr_cls_instances
+				self.mem_to_stream_ratios_dict[i]=self.instances_in_memory_count_dict[i]/self.instances_in_stream_count_dict[i]
+
+			
+		if self.memory_dataset is None:
+			self.memory_dataset = strategy.experience.dataset
+		elif len(self.memory_dataset)+len(strategy.experience.dataset) <= self.memory_size:
+			combined_paths = self.memory_dataset.paths+strategy.experience.dataset.paths
+
+			self.memory_dataset = AvalancheConcatDataset(
+				[self.memory_dataset, strategy.experience.dataset], paths=combined_paths)
+
+
+		else:
+			if len(self.memory_dataset)==self.memory_size:
+
+				experience_data = strategy.experience.dataset
+			elif len(self.memory_dataset)<self.memory_size:
+
+				experience_fill_memory = AvalancheSubset(strategy.experience.dataset, indices=range(self.memory_size-len(self.memory_dataset)), paths=strategy.experience.dataset.paths)
+				experience_data = AvalancheSubset(strategy.experience.dataset, indices=range(self.memory_size-len(self.memory_dataset), len(strategy.experience.dataset)), paths=strategy.experience.dataset.paths)
+
+				memory_dataset_paths = self.memory_dataset.paths+experience_fill_memory.paths
+				self.memory_dataset = AvalancheConcatDataset([self.memory_dataset, experience_fill_memory ], paths=memory_dataset_paths)
+
+
+			all_memory_targets = torch.tensor(self.memory_dataset.targets)
+			u, c = torch.unique(all_memory_targets, return_counts=True)
+			most_freq_cls_in_mem = u[c==torch.max(c)]
+			self.full_classes = torch.unique(torch.cat([self.full_classes, most_freq_cls_in_mem]))
+
+			
+			
+			mem_subset_in_full_classes_idxs = (torch.isin(all_memory_targets, self.full_classes)).nonzero().flatten()
+			mem_subset_NOT_in_full_classes_idxs = (torch.logical_not(torch.isin(all_memory_targets, self.full_classes))).nonzero().flatten()
+
+
+			exp_subset_in_NOT_full_classes_idxs = (torch.logical_not(torch.isin(torch.tensor(experience_data.targets), self.full_classes))).nonzero().flatten()
+
+			idxs_from_large_classes_retained = random.sample(mem_subset_in_full_classes_idxs.tolist(), mem_subset_in_full_classes_idxs.shape[0]-exp_subset_in_NOT_full_classes_idxs.shape[0] )
+			large_classes_retained_dataset = AvalancheSubset(self.memory_dataset, indices=idxs_from_large_classes_retained, paths=self.memory_dataset.paths)
+			NOT_large_classes_memory_subset = AvalancheSubset(self.memory_dataset, indices=mem_subset_NOT_in_full_classes_idxs, paths=self.memory_dataset.paths)
+
+			new_memory = AvalancheSubset(experience_data, indices= exp_subset_in_NOT_full_classes_idxs, paths=experience_data.paths)
+			new_memory_paths =new_memory.paths+NOT_large_classes_memory_subset.paths
+			new_memory = AvalancheConcatDataset([new_memory, NOT_large_classes_memory_subset], paths=new_memory_paths)
+
+			for j, cls_i in enumerate(self.full_classes):
+
+				cls_i_in_mem_retained_idxs = (torch.tensor(large_classes_retained_dataset.targets)==cls_i).nonzero().flatten()#
+
+				cls_i_exp_subset_idxs =  (torch.tensor(experience_data.targets)==cls_i).nonzero().flatten()
+
+
+				uniform_samples = torch.rand(cls_i_exp_subset_idxs.shape[0])
+
+				exp_idxs_for_replacing_mem = (uniform_samples<=self.mem_to_stream_ratios_dict[int(cls_i.item())]).nonzero().flatten()
+				mem_idx_to_replace = np.random.choice(cls_i_in_mem_retained_idxs, size=exp_idxs_for_replacing_mem.shape[0], replace=True)
+				unique_mem_to_replace, index_exp_to_replace = np.unique(mem_idx_to_replace[::-1], return_index=True)
+
+				index_exp_to_replace = -index_exp_to_replace +(len(exp_idxs_for_replacing_mem)-1) 
+				index_exp_to_replace = cls_i_exp_subset_idxs[index_exp_to_replace]
+
+				mem_idx_to_retain = cls_i_in_mem_retained_idxs[torch.logical_not(torch.isin(cls_i_in_mem_retained_idxs, torch.tensor(unique_mem_to_replace)))]
+
+				cls_i_exp_subset = AvalancheSubset(experience_data, indices=index_exp_to_replace, paths=experience_data.paths)
+				cls_i_mem_subset = AvalancheSubset(large_classes_retained_dataset, indices=mem_idx_to_retain, paths=large_classes_retained_dataset.paths)
+
+				cls_i_paths = cls_i_exp_subset.paths+cls_i_mem_subset.paths
+				cls_i_data = AvalancheConcatDataset([cls_i_exp_subset, cls_i_mem_subset], paths=cls_i_paths )
+
+				new_memory_paths = new_memory.paths+ cls_i_data.paths
+
+				new_memory = AvalancheConcatDataset([new_memory, cls_i_data],  paths=new_memory_paths)			
+			self.memory_dataset = new_memory
+		all_memory_targets = torch.tensor(self.memory_dataset.targets)
+		strategy.memory_dataset_paths.append(self.memory_dataset.paths)
+		strategy.memory_dataset_targets.append(all_memory_targets)
+		print('full classes: ', self.full_classes)
+		for i in range(strategy.model.num_classes):
+			self.cls_idx_dict[i] = (all_memory_targets==i).nonzero().flatten()
+			self.instances_in_memory_count_dict[i] = self.cls_idx_dict[i].shape[0]
+
+
+
+	def _get_indices(self, strategy):
+		return self.rehearsal_selection_strategy(strategy, self)
+
+

+ 1 - 0
avalanche/avalanche/training/plugins/evaluation_plugin.py

@@ -17,6 +17,7 @@ import numpy as np
 import torch
 import pickle as pkl
 import datetime
+
 class EvaluationPlugin(StrategyPlugin):
     """
     An evaluation plugin that obtains relevant data from the

+ 720 - 0
avalanche/avalanche/training/plugins/imbalance_focus_replay.py

@@ -0,0 +1,720 @@
+import random
+from abc import ABC, abstractmethod
+from typing import Dict, List, Optional, TYPE_CHECKING
+
+import torch
+from numpy import inf
+import numpy as np
+from torch import Tensor, cat
+from torch.nn import Module, CrossEntropyLoss
+from torch.utils.data import random_split, DataLoader
+from torch.optim import Optimizer, SGD, AdamW
+
+from avalanche.benchmarks.utils import AvalancheConcatDataset, \
+	AvalancheDataset, AvalancheSubset
+from avalanche.benchmarks.utils.data_loader import \
+	ReplayDataLoader
+from avalanche.models import FeatureExtractorBackbone
+from avalanche.training.plugins.strategy_plugin import StrategyPlugin
+
+if TYPE_CHECKING:
+	from avalanche.training.strategies import BaseStrategy
+
+from avalanche.models.utils import avalanche_forward
+from avalanche.benchmarks.utils.data_loader import TaskBalancedDataLoader, _seq_collate_mbatches_fn
+from avalanche.models.dynamic_optimizers import reset_optimizer
+import copy
+import numpy as np
+from scipy import signal
+from avalanche.models.utils import LabelSmoothingCrossEntropy
+
+
+class ClassImbalanceRehersalPlugin(StrategyPlugin):
+	# This is the parent class for the rehearsal strategies in this module, where the (unrealistic)
+	# assumption is that there is unlimited memory space (kept in the attribute cumulative_dataset).
+	# The attribute cls_idx_dict stores, for each integer class label (key), the indices of that
+	# class in cumulative_dataset (value). buffer_data_ratio indicates how many images are selected
+	# from the memory for each image in the current experience.
+	# Subclasses are expected to provide _get_indices(); it is called by
+	# train_dataset_adaptation() but not defined in this class.
+
+	def __init__(self, buffer_data_ratio):
+		super().__init__()
+
+		self.cumulative_dataset = None
+		self.cls_idx_dict = {}
+		self.buffer_data_ratio = buffer_data_ratio
+
+
+	def train_dataset_adaptation(self, strategy: 'BaseStrategy', **kwargs):
+		"""Set ``strategy.adapted_dataset`` to the current experience plus a rehearsal set.
+
+		``_get_indices`` may return None (rehearse the whole cumulative dataset),
+		a list ``[memory_indices, experience_indices]`` (rehearse a subset of the
+		memory plus an oversampled subset of the current experience) or plain
+		memory indices. The picked memory indices (or None) are recorded in
+		``strategy.rehearsal_indicies_picked`` under the current experience number.
+		"""
+
+		indices = None
+		if self.cumulative_dataset is not None:
+
+			# extend the strategy's path list with the paths of the dataset added last
+			strategy.cumulative_dataset_paths = strategy.cumulative_dataset_paths+self.cumulative_dataset._dataset_list[-1].paths
+			# NOTE(review): **kwargs is forwarded here; a _get_indices(self, strategy)
+			# override without **kwargs would raise TypeError for non-empty kwargs.
+			indices = self._get_indices(strategy, **kwargs)
+			if indices is None:
+				self.exp_rehersal_set = self.cumulative_dataset # selection returned None: rehearse everything seen so far
+			elif isinstance(indices, list):
+				exp_oversample_indicies = indices[1]
+				indices = indices[0]
+				self.exp_oversample_subset =  AvalancheSubset(strategy.experience.dataset, indices=exp_oversample_indicies)
+				self.cumulative_dataset_subset  = AvalancheSubset(self.cumulative_dataset, indices=indices)
+
+				self.exp_rehersal_set = AvalancheConcatDataset([self.exp_oversample_subset, self.cumulative_dataset_subset], paths=[self.exp_oversample_subset.paths, self.cumulative_dataset_subset.paths])
+			else:
+				self.exp_rehersal_set = AvalancheSubset(self.cumulative_dataset, indices=indices)
+
+			# NOTE(review): paths is a list of two lists here, whereas the first-experience
+			# branch below passes a single list - confirm which form
+			# AvalancheConcatDataset expects.
+			strategy.adapted_dataset = AvalancheConcatDataset([strategy.experience.dataset, self.exp_rehersal_set], paths=[strategy.experience.dataset.paths, self.exp_rehersal_set.paths])
+
+		else:
+
+			strategy.adapted_dataset = AvalancheConcatDataset([strategy.experience.dataset], paths=strategy.experience.dataset.paths) # in first experience
+
+
+		strategy.rehearsal_indicies_picked[strategy.experience.current_experience] = indices
+
+
+
+
+	def after_training_exp(self, strategy: 'BaseStrategy', **kwargs):
+		"""Add the finished experience to ``cumulative_dataset`` and rebuild ``cls_idx_dict``."""
+
+		if self.cumulative_dataset is None:
+				# first experience: the adapted dataset only wraps the experience dataset
+				self.cumulative_dataset = strategy.adapted_dataset
+		else:
+			self.cumulative_dataset = AvalancheConcatDataset(
+				[self.cumulative_dataset, strategy.experience.dataset], paths=[self.cumulative_dataset.paths, strategy.experience.dataset.paths])
+
+
+		all_targets = torch.tensor(self.cumulative_dataset.targets)
+
+		# indices of every class within the cumulative dataset
+		for i in range(strategy.model.num_classes):
+			self.cls_idx_dict[i] = (all_targets==i).nonzero().flatten()
+
+
+
+class ClassFrequencyRehearsalPlugin(ClassImbalanceRehersalPlugin):
+
+	# This is the rehearsal strategy is considered the baseline to beat for our problem of rehearsal learning in an imbalanced
+	# dataset situation. It can be considered as the standard oversampling method adapted for rehearsal learning,
+	# The class frequency in the data stream up to the current point in time is used to compute the class inverse frequencies
+	# which are then normalized and used as probabilities to draw from the classes for rehearsal set building. 
+	
+	def __init__(self, buffer_data_ratio):
+		super(ClassFrequencyRehearsalPlugin, self).__init__(buffer_data_ratio)
+
+	def _get_frequency_based_indicies(self, strategy, **kwargs):
+
+		nr_indices_to_return = int(len(strategy.experience.dataset)*self.buffer_data_ratio)
+		if len(self.cumulative_dataset)<= nr_indices_to_return:
+			return None
+
+
+		idx_rehersal_prob = torch.zeros((len(self.cumulative_dataset)))
+		total_data_len = len(self.cumulative_dataset)
+		summed_inv_cls_freq = 0
+		for cls_i, c_idxs  in self.cls_idx_dict.items():
+			if c_idxs.shape[0] != 0:
+				inv_cls_freq = 1/c_idxs.shape[0]
+			else:
+				inv_cls_freq = 0
+			summed_inv_cls_freq += inv_cls_freq
+
+			total_seen_for_cls = len(idx_rehersal_prob[self.cls_idx_dict[cls_i]])
+			if total_seen_for_cls != 0 :
+				idx_rehersal_prob[self.cls_idx_dict[cls_i]] = inv_cls_freq/total_seen_for_cls
+			else:
+				idx_rehersal_prob[self.cls_idx_dict[cls_i]] = 0
+		idx_rehersal_prob = idx_rehersal_prob/summed_inv_cls_freq
+
+		idxs = idx_rehersal_prob.multinomial(num_samples=nr_indices_to_return, replacement=False)
+		return idxs
+
+	def _get_indices(self, strategy):
+		return self._get_frequency_based_indicies(strategy)
+
+class ClassErrorRehersalPlugin(ClassImbalanceRehersalPlugin):
+
+	# This is a new rehearsal strategy where the class accuracies for a validation set are used to estimate 
+	# how well the different classes are predicted. Classes that have a higher error rate, i.e. a smaller class accuracy 
+	# are favoured when selecting images from the memory to build a rehearsal set. More precisely, the lower the class accuracy, 
+	# the higher the probability to draw from instances in the memory of that class.
+	# The error rates for all classes are normalized and used as the probability to draw from that class. The probability mass for
+	# a class is devided among all memory instances of that class and then the multinomial function draws from that distribution without replacement. 
+
+	def __init__(self, buffer_data_ratio):
+		super(ClassErrorRehersalPlugin, self).__init__(buffer_data_ratio)
+
+
+
+	def _get_class_error_based_indicies(self, strategy):
+
+		nr_indices_to_return = int(len(strategy.experience.dataset)*self.buffer_data_ratio)
+		if len(self.cumulative_dataset)<= nr_indices_to_return:
+			return None
+
+		assert strategy.val_cls_acc_dict is not None
+
+		idx_rehersal_prob = torch.zeros((len(self.cumulative_dataset)))
+		for cls_i, cls_acc_i in strategy.val_cls_acc_dict.items():
+			if isinstance(cls_acc_i.result(), dict) and 0 in cls_acc_i.result().keys() : 
+			#print('found class acc')
+				cls_acc = cls_acc_i.result()[0]
+
+				cls_rehersal_prob= 1-cls_acc
+			else: 
+				cls_acc = 0
+			cls_rehersal_prob= 1-cls_acc
+			total_seen_for_cls = len(idx_rehersal_prob[self.cls_idx_dict[cls_i]])
+			if total_seen_for_cls!=0:
+				idx_rehersal_prob[self.cls_idx_dict[cls_i]] = cls_rehersal_prob/total_seen_for_cls
+			else:
+				print('should never be the caser1')
+				idx_rehersal_prob[self.cls_idx_dict[cls_i]] = 0
+
+		idxs = idx_rehersal_prob.multinomial(num_samples=nr_indices_to_return, replacement=False)
+		return idxs
+
+	def _get_indices(self, strategy):
+		return self._get_class_error_based_indicies(strategy)
+
+
+
+
+class ClassErrorRehersalTemperaturePlugin(ClassImbalanceRehersalPlugin):
+	"""Class-error based rehearsal sampling with a softmax-style temperature.
+
+	Every class receives the unnormalised weight ``exp(error_rate / temperature)``,
+	where ``error_rate`` is ``1 - validation accuracy`` of that class. The weight
+	is shared equally between all memory samples of the class and the rehearsal
+	indices are drawn from the resulting distribution without replacement.
+	The smaller the temperature, the sharper the distribution, i.e. classes with
+	a high error rate are favoured even more than with the plain class error
+	based method.
+	"""
+
+	def __init__(self, buffer_data_ratio, temperature):
+		"""
+		:param buffer_data_ratio: forwarded to the parent plugin; the number of
+			rehearsal indices is ``int(len(experience dataset) * buffer_data_ratio)``.
+		:param temperature: divisor applied to the class error rates before
+			exponentiation. Must be non-zero.
+		"""
+		super(ClassErrorRehersalTemperaturePlugin, self).__init__(buffer_data_ratio)
+		self.temperature = temperature
+
+	def _get_class_error_based_indicies(self, strategy):
+		"""Draw rehearsal indices weighted by the tempered class error rates.
+
+		:param strategy: the running strategy; ``strategy.experience.dataset`` and
+			``strategy.val_cls_acc_dict`` are read.
+		:return: ``None`` while the memory holds no more samples than would be
+			requested, otherwise a tensor of indices into ``self.cumulative_dataset``.
+		"""
+		nr_indices_to_return = int(len(strategy.experience.dataset)*self.buffer_data_ratio)
+		if len(self.cumulative_dataset) <= nr_indices_to_return:
+			return None
+
+		assert strategy.val_cls_acc_dict is not None
+
+		# Error rate per class; classes without a usable validation result keep
+		# the default error rate of 1 (this only occurs when the validation set
+		# does not contain all classes).
+		# NOTE(review): indexing by cls_i assumes the dict keys are 0..n-1 - confirm.
+		class_error_rates = torch.ones((len(strategy.val_cls_acc_dict)))
+		for cls_i, cls_acc_i in strategy.val_cls_acc_dict.items():
+			result = cls_acc_i.result()
+			if isinstance(result, dict) and 0 in result:
+				class_error_rates[cls_i] = 1 - result[0]
+			else:
+				class_error_rates[cls_i] = 1
+
+		idx_rehersal_prob = torch.zeros((len(self.cumulative_dataset)))
+		for cls_i in strategy.val_cls_acc_dict.keys():
+			cls_idxs = self.cls_idx_dict[cls_i]
+			total_seen_for_cls = len(cls_idxs)
+			if total_seen_for_cls == 0:
+				# nothing of this class in memory, no probability mass to assign
+				continue
+			# class_error_rates[cls_i] already is a tensor; wrapping it in
+			# torch.tensor() again would only copy it and emit a UserWarning
+			cls_rehersal_prob = torch.exp(class_error_rates[cls_i] / self.temperature)
+			idx_rehersal_prob[cls_idxs] = cls_rehersal_prob / total_seen_for_cls
+
+		idxs = idx_rehersal_prob.multinomial(num_samples=nr_indices_to_return, replacement=False)
+		return idxs
+
+	def _get_indices(self, strategy):
+		"""Entry point used by the parent plugin; delegates to the tempered sampler."""
+		return self._get_class_error_based_indicies(strategy)
+
+
+
+class WeightedeMovingClassErrorAverageRehersalPlugin(ClassImbalanceRehersalPlugin):
+	"""Class-error based rehearsal sampling using a weighted moving average.
+
+	The plain class error based sampling did not improve upon the inverse
+	frequency baseline because it was too unstable: the validation accuracy
+	depended heavily on the classes occurring in the most recent experience.
+	This version keeps the (normalised) class error rates of the last
+	``nr_of_steps_to_avg`` refreshes and combines them with a weighted average,
+	so that earlier evaluations contribute less than recent ones.
+
+	The weights are the falling half of a Gaussian window sampled at equal
+	intervals. ``sigma`` is its standard deviation: the smaller ``sigma``, the
+	steeper the drop between the weights and the smaller the impact of earlier
+	evaluations on the average.
+	"""
+
+	def __init__(self, buffer_data_ratio, nr_of_steps_to_avg, nr_classes, sigma):
+		"""
+		:param buffer_data_ratio: forwarded to the parent plugin; the number of
+			rehearsal indices is ``int(len(experience dataset) * buffer_data_ratio)``.
+		:param nr_of_steps_to_avg: number of stored error-rate evaluations that
+			enter the weighted average.
+		:param nr_classes: total number of classes (class ids are used as column
+			indices, so they are expected to be ``0..nr_classes-1``).
+		:param sigma: standard deviation of the Gaussian window providing the weights.
+		"""
+		print('Initializing WeightedeMovingClassErrorAverageRehersalPlugin' )
+		super(WeightedeMovingClassErrorAverageRehersalPlugin, self).__init__(buffer_data_ratio)
+		self.nr_of_steps_to_avg = nr_of_steps_to_avg
+		self.nr_classes = nr_classes
+		self.sigma = sigma
+		# scipy.signal.gaussian was removed from the scipy.signal namespace in
+		# SciPy 1.13; the window function lives in scipy.signal.windows.
+		w = signal.windows.gaussian(self.nr_of_steps_to_avg*2, self.sigma)
+		# keep the falling half only: row 0 (most recent evaluation) gets the largest weight
+		self.w = torch.tensor(w[-self.nr_of_steps_to_avg:])
+		self.recent_cls_error_rates = torch.zeros((self.nr_of_steps_to_avg, self.nr_classes))
+		# populated by the first refresh in _get_weighted_class_error_avg_based_indicies
+		self.weighted_avg_error_rates = None
+		self.started_drawing_from_memory = 0  # currently not read anywhere in this class
+
+	def _refresh_weighted_error_rates(self, strategy):
+		"""Push the current class error rates into the history and recompute the weighted average."""
+		assert strategy.val_cls_acc_dict is not None
+
+		most_recent_error_rates = torch.zeros((1, self.nr_classes))
+		for cls_i, cls_acc_i in strategy.val_cls_acc_dict.items():
+			if len(self.cls_idx_dict[cls_i]) > 0:
+				result = cls_acc_i.result()
+				if isinstance(result, dict) and 0 in result:
+					cls_error_rate = 1 - result[0]
+				else:
+					# no validation result for a class that is in memory: treat as always wrong
+					cls_error_rate = 1
+			else:
+				# class not in memory yet, it cannot be rehearsed
+				cls_error_rate = 0
+			most_recent_error_rates[0, cls_i] = cls_error_rate
+
+		# Normalise only when there is any error at all; dividing by a zero sum
+		# would fill the history with NaN and break the sampling below.
+		error_total = torch.sum(most_recent_error_rates)
+		if error_total > 0:
+			most_recent_error_rates = most_recent_error_rates/error_total
+
+		# newest evaluation goes first, the oldest one is dropped
+		history = torch.concat([most_recent_error_rates, self.recent_cls_error_rates])
+		self.recent_cls_error_rates = history[:-1, :]
+		self.weighted_recent_cls_error_rates = (self.w*self.recent_cls_error_rates.T).T
+		weighted_avg = torch.mean(self.weighted_recent_cls_error_rates, dim=0, dtype=torch.float)
+		weighted_total = torch.sum(self.weighted_recent_cls_error_rates)
+		if weighted_total > 0:
+			weighted_avg = weighted_avg/weighted_total
+		self.weighted_avg_error_rates = weighted_avg
+
+	def _get_weighted_class_error_avg_based_indicies(self, strategy):
+		"""Draw rehearsal indices weighted by the moving average of the class error rates.
+
+		:param strategy: the running strategy; ``strategy.experience.dataset``,
+			``strategy.current_train_exp_seen`` and ``strategy.val_cls_acc_dict`` are read.
+		:return: ``None`` while the memory holds no more samples than would be
+			requested, otherwise a tensor of indices into ``self.cumulative_dataset``.
+		"""
+		nr_indices_to_return = int(len(strategy.experience.dataset)*self.buffer_data_ratio)
+		if len(self.cumulative_dataset) <= nr_indices_to_return:
+			return None
+
+		# NOTE(review): comparing the number of seen experiences with
+		# buffer_data_ratio looks unintended; the condition is kept unchanged.
+		exp_seen = strategy.current_train_exp_seen
+		refresh_due = exp_seen % 5 == 0 or exp_seen == self.buffer_data_ratio
+		# Also refresh when no average exists yet, otherwise the first call that
+		# misses the schedule above would fail on the missing attribute.
+		if refresh_due or getattr(self, 'weighted_avg_error_rates', None) is None:
+			self._refresh_weighted_error_rates(strategy)
+
+		idx_rehersal_prob = torch.zeros((len(self.cumulative_dataset)))
+		for cls_i in range(self.nr_classes):
+			cls_idxs = self.cls_idx_dict[cls_i]
+			total_seen_for_cls = len(cls_idxs)
+			if total_seen_for_cls != 0:
+				idx_rehersal_prob[cls_idxs] = self.weighted_avg_error_rates[cls_i]/total_seen_for_cls
+
+		# Without any recorded error all weights are zero and multinomial would
+		# raise; fall back to uniform sampling from the memory in that case.
+		if not torch.sum(idx_rehersal_prob) > 0:
+			idx_rehersal_prob = torch.ones((len(self.cumulative_dataset)))
+
+		idxs = idx_rehersal_prob.multinomial(num_samples=nr_indices_to_return, replacement=False)
+		return idxs
+
+	def _get_indices(self, strategy):
+		"""Entry point used by the parent plugin; delegates to the moving-average sampler."""
+		return self._get_weighted_class_error_avg_based_indicies(strategy)
+
+
+class ClassErrorFrequencyAvgRehearsalPlugin(ClassImbalanceRehersalPlugin):
+	"""Hybrid of the class inverse frequency and the class error based method.
+
+	The class inverse frequencies and the class error rates are each normalised
+	across all classes, added up and normalised again. The result is used as the
+	probability of a class being picked for rehearsal; it is shared equally
+	between the memory samples of that class.
+	"""
+
+	def __init__(self, buffer_data_ratio, nr_classes):
+		"""
+		:param buffer_data_ratio: forwarded to the parent plugin; the number of
+			rehearsal indices is ``int(len(experience dataset) * buffer_data_ratio)``.
+		:param nr_classes: total number of classes (class ids are used as column
+			indices, so they are expected to be ``0..nr_classes-1``).
+		"""
+		super(ClassErrorFrequencyAvgRehearsalPlugin, self).__init__(buffer_data_ratio)
+		self.nr_classes = nr_classes
+
+	@staticmethod
+	def _normalised(weights):
+		"""Scale ``weights`` to sum to one; an all-zero tensor is returned unchanged."""
+		total = torch.sum(weights)
+		if total > 0:
+			return weights/total
+		return weights
+
+	def _get_frequency_based_indicies(self, strategy, **kwargs):
+		"""Draw rehearsal indices from the averaged error / inverse-frequency weights.
+
+		:param strategy: the running strategy; ``strategy.experience.dataset`` and
+			``strategy.val_cls_acc_dict`` are read.
+		:param kwargs: accepted for interface compatibility, not used.
+		:return: ``None`` while the memory holds no more samples than would be
+			requested, otherwise a tensor of indices into ``self.cumulative_dataset``.
+		"""
+		nr_indices_to_return = int(len(strategy.experience.dataset)*self.buffer_data_ratio)
+		if len(self.cumulative_dataset) <= nr_indices_to_return:
+			return None
+
+		assert strategy.val_cls_acc_dict is not None
+
+		most_recent_error_rates = torch.zeros((1, self.nr_classes))
+		for cls_i, cls_acc_i in strategy.val_cls_acc_dict.items():
+			if len(self.cls_idx_dict[cls_i]) > 0:
+				result = cls_acc_i.result()
+				if isinstance(result, dict) and 0 in result:
+					cls_error_rate = 1 - result[0]
+				else:
+					# this case only occurs when the validation set does not contain all classes
+					cls_error_rate = 1
+			else:
+				# class not in memory yet, it cannot be rehearsed
+				cls_error_rate = 0
+			most_recent_error_rates[0, cls_i] = cls_error_rate
+
+		cls_inv_freq = torch.zeros((1, self.nr_classes))
+		for cls_i, c_idxs in self.cls_idx_dict.items():
+			nr_in_memory = c_idxs.shape[0]
+			cls_inv_freq[0, cls_i] = 1/nr_in_memory if nr_in_memory != 0 else 0
+
+		# Each normalisation is skipped for an all-zero tensor: with perfect
+		# validation accuracy on every class the error rates sum to zero, and a
+		# plain division would turn all weights into NaN and crash multinomial.
+		most_recent_error_rates = self._normalised(most_recent_error_rates)
+		cls_inv_freq = self._normalised(cls_inv_freq)
+		averaged_weights = self._normalised(most_recent_error_rates + cls_inv_freq)
+
+		idx_rehersal_prob = torch.zeros((len(self.cumulative_dataset)))
+		for cls_i in range(self.nr_classes):
+			cls_idxs = self.cls_idx_dict[cls_i]
+			total_seen_for_cls = len(cls_idxs)
+			if total_seen_for_cls != 0:
+				idx_rehersal_prob[cls_idxs] = averaged_weights[0, cls_i]/total_seen_for_cls
+
+		idxs = idx_rehersal_prob.multinomial(num_samples=nr_indices_to_return, replacement=False)
+		return idxs
+
+	def _get_indices(self, strategy):
+		"""Entry point used by the parent plugin; delegates to the hybrid sampler."""
+		return self._get_frequency_based_indicies(strategy)
+
+
+class FillExpBasedRehearsalPlugin(ClassImbalanceRehersalPlugin):
+	"""Balance the finetune set by filling up the experience from memory, without oversampling.
+
+	With this rehearsal strategy the number of images in the finetune set
+	(experience data plus selected memory data, i.e. the rehearsal set) should be
+	roughly the same for each class. This can only roughly be the case, because
+	a class in the current experience may already hold more images than a class
+	would get if the finetune set was divided equally.
+
+	As opposed to FillExpOversampleBasedRehearsalPlugin, no memory or experience
+	sample is used twice. Slots that cannot be filled because a class does not
+	have enough images in memory are filled by drawing randomly from the memory
+	samples that have not been selected yet, so the finetune set is as balanced
+	as possible while the rehearsal set only contains unique instances.
+
+	Procedure:
+
+	1. ``img_per_cls``, the per-class size of a perfectly balanced finetune set,
+	   is computed from the classes that are in memory or in the experience.
+	2. Classes of the experience that already hold at least ``img_per_cls``
+	   images are 'full'; they get no rehearsal samples and ``img_per_cls`` is
+	   recomputed for the remaining classes.
+	3. Every remaining class is topped up from its memory samples, drawn
+	   without replacement.
+	4. Remaining free slots are filled randomly from the unselected memory samples.
+	"""
+
+	def __init__(self, buffer_data_ratio, nr_classes):
+		"""
+		:param buffer_data_ratio: forwarded to the parent plugin; the number of
+			rehearsal indices is ``int(len(experience dataset) * buffer_data_ratio)``.
+		:param nr_classes: total number of classes; stored, not used by this class itself.
+		"""
+		super(FillExpBasedRehearsalPlugin, self).__init__(buffer_data_ratio)
+		self.nr_classes = nr_classes
+
+	def _get_experience_based_indicies(self, strategy, **kwargs):
+		"""Select memory indices that balance the classes of the finetune set.
+
+		:param strategy: the running strategy; ``strategy.experience.dataset``
+			(including its ``targets``) is read.
+		:param kwargs: accepted for interface compatibility, not used.
+		:return: ``None`` while the memory holds no more samples than would be
+			requested, otherwise an int64 tensor of indices into
+			``self.cumulative_dataset``.
+		"""
+		nr_indices_to_return = int(len(strategy.experience.dataset)*self.buffer_data_ratio)
+		if len(self.cumulative_dataset) <= nr_indices_to_return:
+			return None
+
+		targets_in_exp = torch.tensor(strategy.experience.dataset.targets)
+		unique, counts = torch.unique(targets_in_exp, return_counts=True)
+		# plain ints, so all the per-class arithmetic below stays in Python integers
+		exp_counts = {u.item(): int(c) for u, c in zip(unique, counts)}
+
+		# classes that are in memory or in the current experience
+		seen_cls = [
+			cls_i for cls_i, c_idxs in self.cls_idx_dict.items()
+			if c_idxs.shape[0] > 0 or cls_i in exp_counts
+		]
+
+		finetune_set_size = nr_indices_to_return + targets_in_exp.shape[0]
+		img_per_cls = int(finetune_set_size/len(seen_cls))
+
+		# experience classes that already reach the balanced size need no rehearsal
+		full_mask = counts >= img_per_cls
+		counts_over_total = int(torch.sum(counts[full_mask]))
+		seen_cls = torch.tensor(seen_cls)
+		not_full_classes = seen_cls[torch.logical_not(torch.isin(seen_cls, unique[full_mask]))]
+		not_full_classes = set(not_full_classes.tolist())
+
+		# redistribute the slots that are not taken by the full classes;
+		# without any class left to fill everything is drawn randomly below
+		if not_full_classes:
+			img_per_cls = (finetune_set_size - counts_over_total)//len(not_full_classes)
+		else:
+			img_per_cls = 0
+
+		selected = []
+		for cls_i, c_idxs in self.cls_idx_dict.items():
+			if cls_i not in not_full_classes:
+				continue
+			# images of this class still missing after counting the experience
+			img_left_to_draw_for_cls_i = img_per_cls - exp_counts.get(cls_i, 0)
+			nr_in_memory = c_idxs.shape[0]
+			if img_left_to_draw_for_cls_i <= 0 or nr_in_memory == 0:
+				# class already large enough, or new class without memory samples
+				continue
+			if nr_in_memory > img_left_to_draw_for_cls_i:
+				c_idxs = c_idxs[torch.randperm(nr_in_memory)[:img_left_to_draw_for_cls_i]]
+			selected.append(c_idxs.to(torch.long))
+
+		# Indices are kept as int64 throughout: collecting them in a float
+		# tensor would corrupt indices above 2**24.
+		if selected:
+			idxs = torch.cat(selected)
+		else:
+			idxs = torch.empty(0, dtype=torch.long)
+
+		nr_idxs_to_draw_randomly = nr_indices_to_return - idxs.shape[0]
+		if nr_idxs_to_draw_randomly > 0:
+			cumulative_dataset_idxs = torch.arange(0, len(self.cumulative_dataset))
+			not_selected = torch.logical_not(torch.isin(cumulative_dataset_idxs, idxs))
+			cumulative_dataset_idxs_remaining = cumulative_dataset_idxs[not_selected]
+			pos_picked = torch.randperm(len(cumulative_dataset_idxs_remaining))[:nr_idxs_to_draw_randomly]
+			idxs = torch.cat([idxs, cumulative_dataset_idxs_remaining[pos_picked]])
+
+		if idxs.shape[0] != nr_indices_to_return:
+			print('ATTENTION: idxs does not have the correct length')
+			print(len(idxs))
+			print(nr_indices_to_return)
+		return idxs
+
+	def _get_indices(self, strategy):
+		"""Entry point used by the parent plugin; delegates to the fill-up sampler."""
+		return self._get_experience_based_indicies(strategy)
+
+
+class FillExpOversampleBasedRehearsalPlugin(ClassImbalanceRehersalPlugin):
+	"""Balance the finetune set by filling up the experience from memory, with oversampling.
+
+	With this rehearsal strategy the number of images in the finetune set
+	(experience data plus selected memory data, i.e. the rehearsal set) should be
+	roughly the same for each class. This can only roughly be the case, because
+	a class in the current experience may already hold more images than a class
+	would get if the finetune set was divided equally.
+
+	Procedure:
+
+	1. ``img_per_cls``, the per-class size of a perfectly balanced finetune set,
+	   is computed from the classes that are in memory or in the experience.
+	2. Classes of the experience that already hold at least ``img_per_cls``
+	   images are 'full'; they get no rehearsal samples and ``img_per_cls`` is
+	   recomputed for the remaining classes.
+	3. Every remaining class is filled by sampling from its memory samples, with
+	   replacement when the memory holds too few of them. These memory indices
+	   are collected in ``idxs``.
+	4. A class that appears for the first time (nothing in memory yet) is filled
+	   by oversampling the experience itself; these experience indices are
+	   collected in ``exp_oversample_idxs``.
+
+	Both tensors are returned; according to the original author's note the
+	``train_dataset_adaptation`` method of the parent class builds the rehearsal
+	set from them.
+	"""
+
+	def __init__(self, buffer_data_ratio, nr_classes):
+		"""
+		:param buffer_data_ratio: forwarded to the parent plugin; the number of
+			rehearsal indices is ``int(len(experience dataset) * buffer_data_ratio)``.
+		:param nr_classes: total number of classes; stored, not used by this class itself.
+		"""
+		super(FillExpOversampleBasedRehearsalPlugin, self).__init__(buffer_data_ratio)
+		self.nr_classes = nr_classes
+
+	def _get_experience_based_indicies(self, strategy, **kwargs):
+		"""Select memory and experience indices that balance the classes of the finetune set.
+
+		:param strategy: the running strategy; ``strategy.experience.dataset``
+			(including its ``targets``) is read.
+		:param kwargs: accepted for interface compatibility, not used.
+		:return: ``None`` while the memory holds no more samples than would be
+			requested, otherwise the list ``[idxs, exp_oversample_idxs]`` of two
+			int64 tensors: indices into ``self.cumulative_dataset`` and indices
+			into the experience dataset.
+		"""
+		nr_indices_to_return = int(len(strategy.experience.dataset)*self.buffer_data_ratio)
+		if len(self.cumulative_dataset) <= nr_indices_to_return:
+			return None
+
+		targets_in_exp = torch.tensor(strategy.experience.dataset.targets)
+		unique, counts = torch.unique(targets_in_exp, return_counts=True)
+		# plain ints, so all the per-class arithmetic below stays in Python integers
+		exp_counts = {u.item(): int(c) for u, c in zip(unique, counts)}
+
+		# classes that are in memory or in the current experience
+		seen_cls = [
+			cls_i for cls_i, c_idxs in self.cls_idx_dict.items()
+			if c_idxs.shape[0] > 0 or cls_i in exp_counts
+		]
+
+		finetune_set_size = nr_indices_to_return + targets_in_exp.shape[0]
+		img_per_cls = int(finetune_set_size/len(seen_cls))
+
+		# experience classes that already reach the balanced size need no rehearsal
+		full_mask = counts >= img_per_cls
+		counts_over_total = int(torch.sum(counts[full_mask]))
+		seen_cls = torch.tensor(seen_cls)
+		not_full_classes = seen_cls[torch.logical_not(torch.isin(seen_cls, unique[full_mask]))]
+		not_full_classes = set(not_full_classes.tolist())
+
+		# redistribute the slots not taken by the full classes, rounding up
+		if not_full_classes:
+			slots_to_share = finetune_set_size - counts_over_total
+			img_per_cls = -(-slots_to_share//len(not_full_classes))
+		else:
+			img_per_cls = 0
+
+		memory_parts = []
+		experience_parts = []
+		for cls_i, c_idxs in self.cls_idx_dict.items():
+			if cls_i not in not_full_classes:
+				continue
+			# images of this class still missing after counting the experience
+			img_left_to_draw_for_cls_i = img_per_cls - exp_counts.get(cls_i, 0)
+			if img_left_to_draw_for_cls_i <= 0:
+				# the updated img_per_cls is not larger than the class count in the experience
+				continue
+
+			nr_in_memory = c_idxs.shape[0]
+			if nr_in_memory > img_left_to_draw_for_cls_i:
+				# enough memory images: randomly select without replacement
+				picked = torch.randperm(nr_in_memory)[:img_left_to_draw_for_cls_i]
+				memory_parts.append(c_idxs[picked].to(torch.long))
+			elif nr_in_memory > 0:
+				# too few memory images: draw with replacement to fill the slots
+				picked = torch.multinomial(torch.ones((nr_in_memory, )), img_left_to_draw_for_cls_i, replacement=True)
+				memory_parts.append(c_idxs[picked].to(torch.long))
+			else:
+				# class appears for the first time, nothing in memory yet:
+				# oversample its images from the experience instead
+				cls_i_exp_idxs = torch.where(targets_in_exp == cls_i)[0]
+				picked = torch.multinomial(torch.ones((cls_i_exp_idxs.shape[0], )), img_left_to_draw_for_cls_i, replacement=True)
+				experience_parts.append(cls_i_exp_idxs[picked])
+
+		# Indices are kept as int64 throughout: collecting them in a float
+		# tensor would corrupt indices above 2**24.
+		empty = torch.empty(0, dtype=torch.long)
+		idxs = torch.cat(memory_parts) if memory_parts else empty
+		exp_oversample_idxs = torch.cat(experience_parts) if experience_parts else empty
+
+		# Rounding up img_per_cls can select more indices than requested; the
+		# surplus is dropped at random from the larger of the two index sets.
+		surplus = idxs.shape[0] + exp_oversample_idxs.shape[0] - nr_indices_to_return
+		if surplus > 0:
+			if len(idxs) > len(exp_oversample_idxs):
+				keep = torch.randperm(idxs.shape[0])[:max(idxs.shape[0] - surplus, 0)]
+				idxs = idxs[keep]
+			else:
+				keep = torch.randperm(exp_oversample_idxs.shape[0])[:max(exp_oversample_idxs.shape[0] - surplus, 0)]
+				# the reduced selection has to be assigned back, otherwise nothing is dropped
+				exp_oversample_idxs = exp_oversample_idxs[keep]
+
+		return [idxs, exp_oversample_idxs]
+
+	def _get_indices(self, strategy):
+		"""Entry point used by the parent plugin; delegates to the oversampling fill-up sampler."""
+		return self._get_experience_based_indicies(strategy)
+
+
+
+
+
+class RandomRehersal(ClassImbalanceRehersalPlugin):
+	"""Rehearsal strategy that draws uniformly from the memory.
+
+	No class specific information is used.
+	"""
+
+	def __init__(self, buffer_data_ratio):
+		"""
+		:param buffer_data_ratio: forwarded to the parent plugin; the number of
+			rehearsal indices is ``int(len(experience dataset) * buffer_data_ratio)``.
+		"""
+		super(RandomRehersal, self).__init__(buffer_data_ratio)
+
+	def _get_class_random_indicies(self, strategy):
+		"""Sample distinct memory indices uniformly at random.
+
+		:param strategy: the running strategy; ``strategy.experience.dataset`` is read.
+		:return: ``None`` while the memory holds no more samples than would be
+			requested, otherwise a tensor of indices into ``self.cumulative_dataset``.
+		"""
+		wanted = int(len(strategy.experience.dataset)*self.buffer_data_ratio)
+		memory_size = len(self.cumulative_dataset)
+		if memory_size <= wanted:
+			return None
+
+		picked = random.sample(range(memory_size), wanted)
+		return torch.tensor(picked)
+
+	def _get_indices(self, strategy):
+		"""Entry point used by the parent plugin; delegates to the uniform sampler."""
+		return self._get_class_random_indicies(strategy)
+
+
+class MaximallyInterferedRetrievalRehersalPlugin(ClassImbalanceRehersalPlugin):
+	"""Rehearsal strategy following 'Online Continual Learning with Maximally
+	Interfered Retrieval' by Aljundi et al.
+
+	The idea is to select the memory images whose loss increases most if the
+	model was finetuned on the current experience only. A randomly selected
+	subset of the memory is evaluated with the current model, then a virtual
+	update step is performed by finetuning on the current experience, then the
+	updated model is used to evaluate the memory subset again. The candidates
+	are ranked by the increase of their loss and the model weights are restored
+	afterwards, so the actual training starts from the original model.
+
+	This method was not specifically designed for class imbalanced data.
+	However, it is one of the few methods that strategically select images from
+	the memory, which is why it is used for comparison.
+
+	NOTE(review): ``candidate_subset_size`` is stored but not used. The candidate
+	pool has exactly as many elements as are returned, so the ranking currently
+	only orders a uniformly random sample instead of filtering it - confirm the
+	intended pool size.
+	"""
+
+	def __init__(self, buffer_data_ratio, candidate_subset_size=10):
+		"""
+		:param buffer_data_ratio: forwarded to the parent plugin; the number of
+			rehearsal indices is ``int(len(experience dataset) * buffer_data_ratio)``.
+		:param candidate_subset_size: stored on the instance; currently unused.
+		"""
+		# per-sample losses for ranking the candidates
+		self.eval_criterion = LabelSmoothingCrossEntropy(reduction='none')
+		# mean loss for the virtual update
+		self.criterion = LabelSmoothingCrossEntropy()
+		self.candidate_subset_size = candidate_subset_size
+		super(MaximallyInterferedRetrievalRehersalPlugin, self).__init__(buffer_data_ratio=buffer_data_ratio)
+
+	def _unpack_minibatch_MIR(self, mbatch, strategy):
+		"""Split a minibatch into paths, inputs and targets and move the tensors to the device.
+
+		:param mbatch: pair whose first element is the actual minibatch (at
+			least three entries: inputs, targets, ...) and whose second element
+			holds the sample paths at position 0.
+		:param strategy: the running strategy; ``strategy.device`` is read.
+		:return: tuple ``(paths, input_imgs, targets)``.
+		"""
+		paths = mbatch[1][0]
+		mbatch = mbatch[0]
+
+		assert len(mbatch) >= 3
+
+		input_imgs = mbatch[0].to(strategy.device)
+		targets = mbatch[1].to(strategy.device)
+		return paths, input_imgs, targets
+
+	def _candidate_losses(self, strategy, candidate_subset):
+		"""Return the per-sample losses of ``candidate_subset`` in dataset order (model in eval mode)."""
+		data_loader = TaskBalancedDataLoader(
+				candidate_subset,
+				collate_mbatches=_seq_collate_mbatches_fn,
+				batch_size=strategy.train_mb_size,
+				shuffle=False,
+				pin_memory=False)
+		strategy.model.eval()
+		losses = []
+		with torch.no_grad():
+			for mbatch in data_loader:
+				_, input_imgs, targets = self._unpack_minibatch_MIR(mbatch, strategy)
+				mb_output = strategy.model(input_imgs)
+				mb_losses = self.eval_criterion(mb_output, targets)
+				losses.extend(mb_losses.cpu().detach().numpy().tolist())
+		return losses
+
+	def _virtual_update(self, strategy):
+		"""Finetune the model for one pass over the current experience with a fresh AdamW optimizer."""
+		eps = 1e-4
+		weight_decay = 75e-3
+		self.optimizer = AdamW(strategy.model.parameters(), eps=eps, weight_decay=weight_decay)
+		experience_data_loader = TaskBalancedDataLoader(
+				strategy.experience.dataset,
+				collate_mbatches=_seq_collate_mbatches_fn,
+				batch_size=strategy.train_mb_size,
+				shuffle=True,
+				pin_memory=False)
+
+		strategy.model.train()
+		reset_optimizer(self.optimizer, strategy.model)
+
+		for mbatch in experience_data_loader:
+			_, input_imgs, targets = self._unpack_minibatch_MIR(mbatch, strategy)
+			self.optimizer.zero_grad()
+			mb_output = strategy.model(input_imgs)
+			loss = self.criterion(mb_output, targets)
+			loss.backward()
+			self.optimizer.step()
+
+	def _get_MIR_based_indicies(self, strategy):
+		"""Rank a random memory subset by its loss increase under a virtual update.
+
+		The losses before and after the virtual update are stored in
+		``strategy.mir_losses_dict`` under the index of the current experience.
+
+		:param strategy: the running strategy; its model is temporarily
+			finetuned and restored before returning.
+		:return: ``None`` while the memory holds no more samples than would be
+			requested, otherwise a numpy array of indices into
+			``self.cumulative_dataset``, ordered by decreasing loss increase.
+		"""
+		logging_mir_dict = {}
+		nr_indices_to_return = int(len(strategy.experience.dataset)*self.buffer_data_ratio)
+		if len(self.cumulative_dataset) <= nr_indices_to_return:
+			return None
+
+		# the guard above guarantees that the memory is larger than the sample
+		candidate_subset_idxs = sorted(random.sample(range(len(self.cumulative_dataset)), nr_indices_to_return))
+		candidate_subset = AvalancheSubset(self.cumulative_dataset, candidate_subset_idxs)
+		logging_mir_dict['candidate_subset_idxs'] = candidate_subset_idxs
+
+		original_model_weights = copy.deepcopy(strategy.model.state_dict())
+		try:
+			all_losses_pre_update = self._candidate_losses(strategy, candidate_subset)
+			logging_mir_dict['all_losses_pre_update'] = all_losses_pre_update
+
+			self._virtual_update(strategy)
+
+			all_losses_after_update = self._candidate_losses(strategy, candidate_subset)
+			logging_mir_dict['all_losses_after_update'] = all_losses_after_update
+		finally:
+			# the update is only virtual: always go back to the original
+			# weights, also when the evaluation or the update fails
+			strategy.model.load_state_dict(original_model_weights)
+
+		# Interference is the INCREASE of the loss caused by the virtual update
+		# (after - before); the largest increase is ranked first.
+		loss_increase = np.array(all_losses_after_update) - np.array(all_losses_pre_update)
+		from_candidates_idxs = np.argsort(-loss_increase, kind='stable')
+
+		strategy.mir_losses_dict[strategy.experience.current_experience] = logging_mir_dict
+		idxs = np.array(candidate_subset_idxs)[from_candidates_idxs][:nr_indices_to_return]
+		return idxs
+
+	def _get_indices(self, strategy):
+		"""Entry point used by the parent plugin; delegates to the MIR sampler."""
+		return self._get_MIR_based_indicies(strategy)

+ 29 - 0
avalanche/avalanche/training/plugins/label_smoothing_loss.py

@@ -0,0 +1,29 @@
+from torch import nn
+from torch.nn import functional as F
+
+def linear_combination(x, y, epsilon):
+    """Blend ``x`` and ``y``: ``epsilon`` weights ``x``, ``1 - epsilon`` weights ``y``."""
+    x_part = epsilon * x
+    y_part = (1 - epsilon) * y
+    return x_part + y_part
+
+def reduce_loss(loss, reduction="mean"):
+    """Reduce ``loss`` with its ``mean()`` or ``sum()``.
+
+    Any ``reduction`` other than ``"mean"`` or ``"sum"`` returns ``loss`` unchanged.
+    """
+    if reduction == "sum":
+        return loss.sum()
+    if reduction == "mean":
+        return loss.mean()
+    return loss
+    
+class LabelSmoothingCrossEntropy(nn.Module):
+    """Cross entropy loss with label smoothing.
+
+    The result is ``epsilon * (-sum(log_softmax) / n_classes) + (1 - epsilon) * nll``,
+    with both terms reduced according to ``reduction``.
+    """
+
+    def __init__(self, epsilon: float = 0.1, reduction="mean"):
+        """
+        :param epsilon: weight of the smoothing term.
+        :param reduction: ``"mean"``, ``"sum"`` or anything else for no
+            reduction of the smoothing term; also passed to ``F.nll_loss``.
+        """
+        super(LabelSmoothingCrossEntropy, self).__init__()
+        self.reduction = reduction
+        self.epsilon = epsilon
+
+    def forward(self, preds, target):
+        """Compute the smoothed loss for logits ``preds`` (classes on the last axis) and ``target``."""
+        log_preds = F.log_softmax(preds, dim=-1)
+        n_classes = preds.size()[-1]
+
+        # smoothing term: negative log-probabilities summed over all classes
+        smoothing_loss = reduce_loss(-log_preds.sum(dim=-1), self.reduction)
+        # regular negative log-likelihood of the target class
+        nll = F.nll_loss(log_preds, target, reduction=self.reduction)
+
+        return linear_combination(smoothing_loss / n_classes, nll, self.epsilon)

+ 7 - 1
avalanche/avalanche/training/strategies/base_strategy.py

@@ -31,6 +31,7 @@ if TYPE_CHECKING:
     from avalanche.core import StrategyCallbacks
     from avalanche.training.plugins import StrategyPlugin
 from pathlib import Path
+import numpy as np
 
 logger = logging.getLogger(__name__)
 
@@ -218,7 +219,12 @@ class BaseStrategy:
 
         self.label_dict = label_dict 
         #""" Dictionary with int-labels as keys and class names as values """
-
+        self.MIR_selected_imgs = np.array([])
+        self.mir_losses_dict = {}
+        self.cumulative_dataset_paths =[]
+        self.memory_dataset_paths =[]
+        self.memory_dataset_targets=[]
+        self.rehearsal_indicies_picked = {}
 
     @property
     def is_eval(self):

+ 5 - 3
avalanche/avalanche/training/strategies/strategy_wrappers.py

@@ -37,10 +37,11 @@ class Naive(BaseStrategy):
                  train_mb_size: int = 1, train_epochs: int = 1,
                  eval_mb_size: int = None, device=None,
                  plugins: Optional[List[StrategyPlugin]] = None,
-                 evaluator: EvaluationPlugin = default_logger, eval_every=-1):
+                 evaluator: EvaluationPlugin = default_logger, eval_every=-1, label_dict=None):
+
         """
         Creates an instance of the Naive strategy.
-
+        
         :param model: The model.
         :param optimizer: The optimizer to use.
         :param criterion: The loss criterion to use.
@@ -59,11 +60,12 @@ class Naive(BaseStrategy):
                 if >0: calls `eval` every `eval_every` epochs and at the end
                     of all the epochs for a single experience.
         """
+
         super().__init__(
             model, optimizer, criterion,
             train_mb_size=train_mb_size, train_epochs=train_epochs,
             eval_mb_size=eval_mb_size, device=device, plugins=plugins,
-            evaluator=evaluator, eval_every=eval_every)
+            evaluator=evaluator, eval_every=eval_every, label_dict=label_dict)
 
 
 class PNNStrategy(BaseStrategy):

+ 3 - 1
avalanche/avalanche/training/utils.py

@@ -279,6 +279,7 @@ def examples_per_class(targets):
     return result
 
 
+
 __all__ = [
     'load_all_dataset',
     'zerolike_params_dict',
@@ -294,5 +295,6 @@ __all__ = [
     'freeze_everything',
     'unfreeze_everything',
     'freeze_up_to',
-    'examples_per_class'
+    'examples_per_class',
+
 ]