|
@@ -0,0 +1,720 @@
|
|
|
+import random
|
|
|
+from abc import ABC, abstractmethod
|
|
|
+from typing import Dict, List, Optional, TYPE_CHECKING
|
|
|
+
|
|
|
+import torch
|
|
|
+from numpy import inf
|
|
|
+import numpy as np
|
|
|
+from torch import Tensor, cat
|
|
|
+from torch.nn import Module, CrossEntropyLoss
|
|
|
+from torch.utils.data import random_split, DataLoader
|
|
|
+from torch.optim import Optimizer, SGD, AdamW
|
|
|
+
|
|
|
+from avalanche.benchmarks.utils import AvalancheConcatDataset, \
|
|
|
+ AvalancheDataset, AvalancheSubset
|
|
|
+from avalanche.benchmarks.utils.data_loader import \
|
|
|
+ ReplayDataLoader
|
|
|
+from avalanche.models import FeatureExtractorBackbone
|
|
|
+from avalanche.training.plugins.strategy_plugin import StrategyPlugin
|
|
|
+
|
|
|
+if TYPE_CHECKING:
|
|
|
+ from avalanche.training.strategies import BaseStrategy
|
|
|
+
|
|
|
+from avalanche.models.utils import avalanche_forward
|
|
|
+from avalanche.benchmarks.utils.data_loader import TaskBalancedDataLoader, _seq_collate_mbatches_fn
|
|
|
+from avalanche.models.dynamic_optimizers import reset_optimizer
|
|
|
+import copy
|
|
|
|
|
|
+from scipy import signal
|
|
|
+from avalanche.models.utils import LabelSmoothingCrossEntropy
|
|
|
+
|
|
|
+
|
|
|
+class ClassImbalanceRehersalPlugin(StrategyPlugin):
|
|
|
+ # This is the parent class for all rehearsal strategies implemented below. The (unrealistic)
+ # assumption is that there is unlimited memory space (accumulated in the instance variable cumulative_dataset).
+ # The instance variable cls_idx_dict maps each integer class label to the memory indices of that class.
+ # The buffer_data_ratio variable indicates how many images are selected from the memory for
+ # each image in the current experience.
|
|
|
+
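+ # Illustrative sketch (assumed example values, not taken from the experiments): with
+ # buffer_data_ratio = 4 and a current experience of 50 images, the subclasses below aim to
+ # draw int(50 * 4) = 200 images from the memory, i.e.
+ #     nr_indices_to_return = int(len(strategy.experience.dataset) * self.buffer_data_ratio)
+ # cls_idx_dict then looks like {0: tensor([0, 3, 7, ...]), 1: tensor([1, 2, ...]), ...},
+ # i.e. one tensor of memory indices per integer class label.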
|
|
|
+ def __init__(self, buffer_data_ratio):
|
|
|
+ super().__init__()
|
|
|
+
|
|
|
+ self.cumulative_dataset = None
|
|
|
+ self.cls_idx_dict = {}
|
|
|
+ self.buffer_data_ratio = buffer_data_ratio
|
|
|
+
|
|
|
+
|
|
|
+ def train_dataset_adaptation(self, strategy: 'BaseStrategy', **kwargs):
|
|
|
+
|
|
|
+ indices = None
|
|
|
+ if self.cumulative_dataset is not None:
|
|
|
+
|
|
|
+ strategy.cumulative_dataset_paths = strategy.cumulative_dataset_paths+self.cumulative_dataset._dataset_list[-1].paths
|
|
|
+ indices = self._get_indices(strategy, **kwargs)
|
|
|
+ if indices is None:
|
|
|
+ self.exp_rehersal_set = self.cumulative_dataset # the _get_indices methods return None until the memory holds more images than the requested rehearsal set
|
|
|
+ elif isinstance(indices, list):
|
|
|
+ exp_oversample_indicies = indices[1]
|
|
|
+ indices = indices[0]
|
|
|
+ self.exp_oversample_subset = AvalancheSubset(strategy.experience.dataset, indices=exp_oversample_indicies)
|
|
|
+ self.cumulative_dataset_subset = AvalancheSubset(self.cumulative_dataset, indices=indices)
|
|
|
+
|
|
|
+ self.exp_rehersal_set = AvalancheConcatDataset([self.exp_oversample_subset, self.cumulative_dataset_subset], paths=[self.exp_oversample_subset.paths, self.cumulative_dataset_subset.paths])
|
|
|
+ else:
|
|
|
+ self.exp_rehersal_set = AvalancheSubset(self.cumulative_dataset, indices=indices)
|
|
|
+
|
|
|
+ strategy.adapted_dataset = AvalancheConcatDataset([strategy.experience.dataset, self.exp_rehersal_set], paths=[strategy.experience.dataset.paths, self.exp_rehersal_set.paths])
|
|
|
+
|
|
|
+ else:
|
|
|
+
|
|
|
+ strategy.adapted_dataset = AvalancheConcatDataset([strategy.experience.dataset], paths=strategy.experience.dataset.paths) # in first experience
|
|
|
+
|
|
|
+
|
|
|
+ strategy.rehearsal_indicies_picked[strategy.experience.current_experience] = indices
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+ def after_training_exp(self, strategy: 'BaseStrategy', **kwargs):
|
|
|
+
|
|
|
+ if self.cumulative_dataset is None:
|
|
|
+ self.cumulative_dataset = strategy.adapted_dataset
|
|
|
+ else:
|
|
|
+ self.cumulative_dataset = AvalancheConcatDataset(
|
|
|
+ [self.cumulative_dataset, strategy.experience.dataset], paths=[self.cumulative_dataset.paths, strategy.experience.dataset.paths])
|
|
|
+
|
|
|
+
|
|
|
+ all_targets = torch.tensor(self.cumulative_dataset.targets)
|
|
|
+
|
|
|
+ for i in range(strategy.model.num_classes):
|
|
|
+ self.cls_idx_dict[i] = (all_targets==i).nonzero().flatten()
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+class ClassFrequencyRehearsalPlugin(ClassImbalanceRehersalPlugin):
|
|
|
+
|
|
|
+ # This rehearsal strategy is the baseline to beat for our problem of rehearsal learning on an imbalanced
+ # data stream. It can be seen as standard oversampling adapted to rehearsal learning:
+ # the class frequencies in the data stream up to the current point in time are used to compute the inverse
+ # class frequencies, which are then normalized and used as probabilities to draw from the classes when building the rehearsal set.
|
|
|
+
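+ # Illustrative sketch (assumed class counts, not taken from the experiments): with memory
+ # counts of {0: 300, 1: 30, 2: 3}, the inverse frequencies are
+ #     inv_freq = torch.tensor([1/300, 1/30, 1/3])
+ #     cls_probs = inv_freq / inv_freq.sum()      # ~[0.009, 0.090, 0.901]
+ # so the rarest class receives ~90% of the sampling mass; each class's mass is then spread
+ # evenly over that class's memory indices before torch.multinomial draws without replacement.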
|
|
|
+ def __init__(self, buffer_data_ratio):
|
|
|
+ super(ClassFrequencyRehearsalPlugin, self).__init__(buffer_data_ratio)
|
|
|
+
|
|
|
+ def _get_frequency_based_indicies(self, strategy, **kwargs):
|
|
|
+
|
|
|
+ nr_indices_to_return = int(len(strategy.experience.dataset)*self.buffer_data_ratio)
|
|
|
+ if len(self.cumulative_dataset)<= nr_indices_to_return:
|
|
|
+ return None
|
|
|
+
|
|
|
+
|
|
|
+ idx_rehersal_prob = torch.zeros((len(self.cumulative_dataset)))
|
|
|
+ total_data_len = len(self.cumulative_dataset)
|
|
|
+ summed_inv_cls_freq = 0
|
|
|
+ for cls_i, c_idxs in self.cls_idx_dict.items():
|
|
|
+ if c_idxs.shape[0] != 0:
|
|
|
+ inv_cls_freq = 1/c_idxs.shape[0]
|
|
|
+ else:
|
|
|
+ inv_cls_freq = 0
|
|
|
+ summed_inv_cls_freq += inv_cls_freq
|
|
|
+
|
|
|
+ total_seen_for_cls = len(idx_rehersal_prob[self.cls_idx_dict[cls_i]])
|
|
|
+ if total_seen_for_cls != 0 :
|
|
|
+ idx_rehersal_prob[self.cls_idx_dict[cls_i]] = inv_cls_freq/total_seen_for_cls
|
|
|
+ else:
|
|
|
+ idx_rehersal_prob[self.cls_idx_dict[cls_i]] = 0
|
|
|
+ idx_rehersal_prob = idx_rehersal_prob/summed_inv_cls_freq
|
|
|
+
|
|
|
+ idxs = idx_rehersal_prob.multinomial(num_samples=nr_indices_to_return, replacement=False)
|
|
|
+ return idxs
|
|
|
+
|
|
|
+ def _get_indices(self, strategy):
|
|
|
+ return self._get_frequency_based_indicies(strategy)
|
|
|
+
|
|
|
+class ClassErrorRehersalPlugin(ClassImbalanceRehersalPlugin):
|
|
|
+
|
|
|
+ # This is a new rehearsal strategy in which per-class accuracies on a validation set are used to estimate
+ # how well the different classes are predicted. Classes with a higher error rate, i.e. a lower class accuracy,
+ # are favoured when selecting images from the memory to build the rehearsal set. More precisely, the lower the class accuracy,
+ # the higher the probability of drawing memory instances of that class.
+ # The error rates of all classes are normalized and used as the probability to draw from each class. The probability mass of
+ # a class is divided among all memory instances of that class, and the multinomial function then draws from that distribution without replacement.
|
|
|
+
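+ # Illustrative sketch (assumed accuracies, not measured values): with validation class
+ # accuracies of {0: 0.9, 1: 0.5}, the unnormalized class weights are the error rates
+ #     1 - 0.9 = 0.1   and   1 - 0.5 = 0.5,
+ # each divided by the number of memory instances of that class, so that torch.multinomial
+ # (which normalizes its input) draws class 1 roughly five times as often as class 0.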
|
|
|
+ def __init__(self, buffer_data_ratio):
|
|
|
+ super(ClassErrorRehersalPlugin, self).__init__(buffer_data_ratio)
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+ def _get_class_error_based_indicies(self, strategy):
|
|
|
+
|
|
|
+ nr_indices_to_return = int(len(strategy.experience.dataset)*self.buffer_data_ratio)
|
|
|
+ if len(self.cumulative_dataset)<= nr_indices_to_return:
|
|
|
+ return None
|
|
|
+
|
|
|
+ assert strategy.val_cls_acc_dict is not None
|
|
|
+
|
|
|
+ idx_rehersal_prob = torch.zeros((len(self.cumulative_dataset)))
|
|
|
+ for cls_i, cls_acc_i in strategy.val_cls_acc_dict.items():
|
|
|
+ if isinstance(cls_acc_i.result(), dict) and 0 in cls_acc_i.result().keys() :
|
|
|
+ #print('found class acc')
|
|
|
+ cls_acc = cls_acc_i.result()[0]
|
|
|
+
|
|
|
+ cls_rehersal_prob= 1-cls_acc
|
|
|
+ else:
|
|
|
+ cls_acc = 0
|
|
|
+ cls_rehersal_prob= 1-cls_acc
|
|
|
+ total_seen_for_cls = len(idx_rehersal_prob[self.cls_idx_dict[cls_i]])
|
|
|
+ if total_seen_for_cls!=0:
|
|
|
+ idx_rehersal_prob[self.cls_idx_dict[cls_i]] = cls_rehersal_prob/total_seen_for_cls
|
|
|
+ else:
|
|
|
+ print('should never be the case')
|
|
|
+ idx_rehersal_prob[self.cls_idx_dict[cls_i]] = 0
|
|
|
+
|
|
|
+ idxs = idx_rehersal_prob.multinomial(num_samples=nr_indices_to_return, replacement=False)
|
|
|
+ return idxs
|
|
|
+
|
|
|
+ def _get_indices(self, strategy):
|
|
|
+ return self._get_class_error_based_indicies(strategy)
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+class ClassErrorRehersalTemperaturePlugin(ClassImbalanceRehersalPlugin):
|
|
|
+
|
|
|
+ # This is another extension of the class-error-based rehearsal method, using the temperature parameter of the softmax function to 'sharpen' or
+ # 'flatten' the class probabilities. The smaller the temperature value, the sharper the distribution. This means that images
+ # of classes with a high error rate, which are already selected for rehearsal with a higher probability by the standard
+ # class-error-based method, are favoured even more when the temperature value is small.
|
|
|
+
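+ # Illustrative sketch (assumed error rates, hypothetical temperature values): for error
+ # rates of 0.2 and 0.8,
+ #     torch.exp(torch.tensor([0.2, 0.8]) / 1.0)   # ~[1.22, 2.23]   -> ratio ~1.8
+ #     torch.exp(torch.tensor([0.2, 0.8]) / 0.1)   # ~[7.4, 2981.0]  -> ratio ~400
+ # so a small temperature concentrates nearly all of the sampling mass on the worst class.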
|
|
|
+
|
|
|
+ def __init__(self, buffer_data_ratio, temperature):
|
|
|
+ super(ClassErrorRehersalTemperaturePlugin, self).__init__(buffer_data_ratio)
|
|
|
+ self.temperature = temperature
|
|
|
+
|
|
|
+
|
|
|
+ def _get_class_error_based_indicies(self, strategy):
|
|
|
+
|
|
|
+ nr_indices_to_return = int(len(strategy.experience.dataset)*self.buffer_data_ratio)
|
|
|
+ if len(self.cumulative_dataset)<= nr_indices_to_return:
|
|
|
+ return None
|
|
|
+
|
|
|
+ assert strategy.val_cls_acc_dict is not None
|
|
|
+
|
|
|
+ class_error_rates = torch.ones((len(strategy.val_cls_acc_dict)))
|
|
|
+ for cls_i, cls_acc_i in strategy.val_cls_acc_dict.items():
|
|
|
+ if isinstance(cls_acc_i.result(), dict) and 0 in cls_acc_i.result().keys() :
|
|
|
+ cls_acc = cls_acc_i.result()[0]
|
|
|
+ cls_rehersal_prob= 1-cls_acc
|
|
|
+ else: # this case only occurs when the validation set does not contain all classes.
|
|
|
+ cls_acc = 0
|
|
|
+ class_error_rates[cls_i] = 1-cls_acc
|
|
|
+
|
|
|
+
|
|
|
+ idx_rehersal_prob = torch.zeros((len(self.cumulative_dataset)))
|
|
|
+ for cls_i in strategy.val_cls_acc_dict.keys():
|
|
|
+
|
|
|
+ cls_rehersal_prob= torch.exp(torch.tensor(class_error_rates[cls_i]/self.temperature))
|
|
|
+ total_seen_for_cls = len(self.cls_idx_dict[cls_i])
|
|
|
+ if total_seen_for_cls!=0:
|
|
|
+ idx_rehersal_prob[self.cls_idx_dict[cls_i]] = cls_rehersal_prob/total_seen_for_cls
|
|
|
+ else:
|
|
|
+ idx_rehersal_prob[self.cls_idx_dict[cls_i]] = 0
|
|
|
+
|
|
|
+ idxs = idx_rehersal_prob.multinomial(num_samples=nr_indices_to_return, replacement=False)
|
|
|
+ return idxs
|
|
|
+
|
|
|
+ def _get_indices(self, strategy):
|
|
|
+ return self._get_class_error_based_indicies(strategy)
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+class WeightedeMovingClassErrorAverageRehersalPlugin(ClassImbalanceRehersalPlugin):
|
|
|
+
|
|
|
+ # The initial class-error-based sampling method did not improve upon the inverse-frequency baseline. After close inspection,
+ # we found that the method was simply too unstable: the validation accuracy depended heavily on the classes occurring in the most recent
+ # experience. The weighted-moving-average version of this method stores the validation accuracies from several evaluations back and uses a weighted average to
+ # get a better estimate of which classes are being recognised less well. The earlier the evaluation, the smaller its weight in that average.
+
+ # sigma defines the shape of the Gaussian curve from which the weights are sampled at equal intervals. The smaller the sigma, the steeper the
+ # difference between the weights and the less impact earlier evaluations on the validation set have on the average.
+ # nr_of_steps_to_avg defines how many evaluations of the validation data are used in the averaging.
|
|
|
+
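+ # Illustrative sketch (assumed parameter values): with nr_of_steps_to_avg = 3 and sigma = 1,
+ #     w = signal.gaussian(6, 1)[-3:]    # ~[0.88, 0.32, 0.04]
+ # so the most recent stored evaluation (row 0 of recent_cls_error_rates) contributes the
+ # largest weight and the oldest one contributes almost nothing to the weighted average.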
|
|
|
+
|
|
|
+ def __init__(self, buffer_data_ratio, nr_of_steps_to_avg, nr_classes, sigma):
|
|
|
+ print('Initializing WeightedeMovingClassErrorAverageRehersalPlugin' )
|
|
|
+ super(WeightedeMovingClassErrorAverageRehersalPlugin, self).__init__(buffer_data_ratio)
|
|
|
+ self.nr_of_steps_to_avg = nr_of_steps_to_avg
|
|
|
+ self.nr_classes = nr_classes
|
|
|
+ self.sigma = sigma
|
|
|
+ w = signal.gaussian(self.nr_of_steps_to_avg*2, self.sigma)
|
|
|
+ self.w =torch.tensor(w[-self.nr_of_steps_to_avg:])
|
|
|
+ self.recent_cls_error_rates = torch.zeros(( self.nr_of_steps_to_avg, self.nr_classes))
|
|
|
+ self.started_drawing_from_memory = 0 # used to update recent_cls_error_rates only after rehearsal from memory has started
|
|
|
+
|
|
|
+ def _get_weighted_class_error_avg_based_indicies(self, strategy):
|
|
|
+
|
|
|
+ nr_indices_to_return = int(len(strategy.experience.dataset)*self.buffer_data_ratio)
|
|
|
+ if len(self.cumulative_dataset)<= nr_indices_to_return:
|
|
|
+ return None
|
|
|
+
|
|
|
+ idx_rehersal_prob = torch.zeros((len(self.cumulative_dataset)))
|
|
|
+
|
|
|
+ if strategy.current_train_exp_seen%5==0 or strategy.current_train_exp_seen == self.buffer_data_ratio:
|
|
|
+ assert strategy.val_cls_acc_dict is not None
|
|
|
+
|
|
|
+ idx_rehersal_prob = torch.zeros((len(self.cumulative_dataset)))
|
|
|
+
|
|
|
+ most_recent_error_rates = torch.zeros((1, self.nr_classes))
|
|
|
+ for cls_i, cls_acc_i in strategy.val_cls_acc_dict.items():
|
|
|
+ total_seen_for_cls = len(idx_rehersal_prob[self.cls_idx_dict[cls_i]])
|
|
|
+ if total_seen_for_cls >0:
|
|
|
+ if isinstance(cls_acc_i.result(), dict) and 0 in cls_acc_i.result().keys() :
|
|
|
+ cls_acc = cls_acc_i.result()[0]
|
|
|
+ cls_error_rate = 1-cls_acc
|
|
|
+ else:
|
|
|
+ cls_error_rate = 1
|
|
|
+ else:
|
|
|
+ cls_error_rate = 0
|
|
|
+
|
|
|
+ most_recent_error_rates[0, cls_i] = cls_error_rate
|
|
|
+
|
|
|
+ most_recent_error_rates = most_recent_error_rates/torch.sum(most_recent_error_rates)
|
|
|
+ self.recent_cls_error_rates = torch.concat([most_recent_error_rates, self.recent_cls_error_rates])
|
|
|
+ self.recent_cls_error_rates = self.recent_cls_error_rates[:-1, :]
|
|
|
+ self.weighted_recent_cls_error_rates = (self.w*self.recent_cls_error_rates.T).T
|
|
|
+ self.weighted_avg_error_rates = torch.mean(self.weighted_recent_cls_error_rates, dim=0, dtype=torch.float)
|
|
|
+ self.weighted_avg_error_rates = self.weighted_avg_error_rates/torch.sum(self.weighted_recent_cls_error_rates)
|
|
|
+
|
|
|
+
|
|
|
+ for cls_i in range(self.nr_classes):
|
|
|
+ total_seen_for_cls = len(self.cls_idx_dict[cls_i])
|
|
|
+
|
|
|
+ if total_seen_for_cls!=0:
|
|
|
+ idx_rehersal_prob[self.cls_idx_dict[cls_i]] = self.weighted_avg_error_rates[cls_i]/total_seen_for_cls
|
|
|
+
|
|
|
+
|
|
|
+ idxs = idx_rehersal_prob.multinomial(num_samples=nr_indices_to_return, replacement=False)
|
|
|
+ return idxs
|
|
|
+
|
|
|
+ def _get_indices(self, strategy):
|
|
|
+ return self._get_weighted_class_error_avg_based_indicies(strategy)
|
|
|
+
|
|
|
+
|
|
|
+class ClassErrorFrequencyAvgRehearsalPlugin(ClassImbalanceRehersalPlugin):
|
|
|
+
|
|
|
+ # This rehearsal plugin is a hybrid between the standard inverse-class-frequency method and the class-error-based method.
+ # The inverse class frequencies and the class error rates are each normalized across all classes, then simply averaged
+ # and used to assign each class a probability of being picked for rehearsal sampling.
|
|
|
+
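+ # Illustrative sketch (assumed values): with normalized inverse frequencies [0.7, 0.3] and
+ # normalized error rates [0.2, 0.8], the element-wise sum is [0.9, 1.1]; renormalizing
+ # gives class weights [0.45, 0.55], i.e. the two signals pull the sampling probabilities
+ # towards their average.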
|
|
|
+ def __init__(self, buffer_data_ratio, nr_classes):
|
|
|
+ super(ClassErrorFrequencyAvgRehearsalPlugin, self).__init__(buffer_data_ratio)
|
|
|
+ self.nr_classes = nr_classes
|
|
|
+
|
|
|
+ def _get_frequency_based_indicies(self, strategy, **kwargs):
|
|
|
+
|
|
|
+ nr_indices_to_return = int(len(strategy.experience.dataset)*self.buffer_data_ratio)
|
|
|
+ if len(self.cumulative_dataset)<= nr_indices_to_return:
|
|
|
+ return None
|
|
|
+
|
|
|
+ assert strategy.val_cls_acc_dict is not None
|
|
|
+
|
|
|
+ idx_rehersal_prob = torch.zeros((len(self.cumulative_dataset)))
|
|
|
+ total_data_len = len(self.cumulative_dataset)
|
|
|
+ most_recent_error_rates = torch.zeros((1, self.nr_classes))
|
|
|
+ cls_inv_freq = torch.zeros((1, self.nr_classes))
|
|
|
+ for cls_i, cls_acc_i in strategy.val_cls_acc_dict.items():
|
|
|
+ total_seen_for_cls = len(idx_rehersal_prob[self.cls_idx_dict[cls_i]])
|
|
|
+ if total_seen_for_cls >0:
|
|
|
+ if isinstance(cls_acc_i.result(), dict) and 0 in cls_acc_i.result().keys() :
|
|
|
+ cls_acc = cls_acc_i.result()[0]
|
|
|
+
|
|
|
+ cls_error_rate = 1-cls_acc
|
|
|
+ else: # this case only occurs when the validation set does not contain all classes.
|
|
|
+ cls_error_rate = 1
|
|
|
+ else:
|
|
|
+ cls_error_rate = 0
|
|
|
+ most_recent_error_rates[0,cls_i] = cls_error_rate
|
|
|
+ for cls_i, c_idxs in self.cls_idx_dict.items():
|
|
|
+ if self.cls_idx_dict[cls_i].shape[0] != 0:
|
|
|
+ inv_freq = 1/c_idxs.shape[0]
|
|
|
+ else:
|
|
|
+ inv_freq = 0
|
|
|
+
|
|
|
+ cls_inv_freq[0,cls_i] = inv_freq
|
|
|
+
|
|
|
+
|
|
|
+ most_recent_error_rates = most_recent_error_rates/torch.sum(most_recent_error_rates)
|
|
|
+ cls_inv_freq = cls_inv_freq/torch.sum(cls_inv_freq)
|
|
|
+
|
|
|
+ elementwise_sum = most_recent_error_rates+cls_inv_freq
|
|
|
+
|
|
|
+ averaged_weights = elementwise_sum/torch.sum(elementwise_sum)
|
|
|
+
|
|
|
+ for cls_i in range(self.nr_classes):
|
|
|
+ total_seen_for_cls = len(self.cls_idx_dict[cls_i])
|
|
|
+ if total_seen_for_cls != 0 :
|
|
|
+ idx_rehersal_prob[self.cls_idx_dict[cls_i]] = averaged_weights[0, cls_i]/total_seen_for_cls
|
|
|
+
|
|
|
+
|
|
|
+ idxs = idx_rehersal_prob.multinomial(num_samples=nr_indices_to_return, replacement=False)
|
|
|
+ return idxs
|
|
|
+
|
|
|
+ def _get_indices(self, strategy):
|
|
|
+ return self._get_frequency_based_indicies(strategy)
|
|
|
+
|
|
|
+
|
|
|
+class FillExpBasedRehearsalPlugin(ClassImbalanceRehersalPlugin):
|
|
|
+ # With this rehearsal strategy the number of images in the finetune set (experience data plus selected memory data, i.e. the rehearsal set)
+ # should be roughly the same for each class. This can only hold approximately, because sometimes a class in the
+ # current experience is larger than the per-class share it would get if the finetune set were divided equally.
+ # As opposed to the FillExpOversampleBasedRehearsalPlugin, this plugin does not oversample memory or experience data to make
+ # the finetune set as close to a balanced dataset as possible. Here, the 'empty' slots that cannot be filled because a class
+ # does not have enough images in the memory or current experience are filled by drawing randomly from the memory data that is left after the
+ # classes have been used to fill the rehearsal set. The effect is that the finetune set is as balanced as possible while constraining the rehearsal
+ # set to unique instances.
+ #
+ # First, the number of images each class should contribute to a perfectly balanced finetune set is calculated (img_per_cls),
+ # based on the unique classes in the memory (cls_idx_dict) and the unique classes in the current experience.
+ # All unique classes in memory and current experience are unified in the seen_cls list.
+ # The classes from that list that are 'full', i.e. have more images than img_per_cls, are removed to form the not_full_classes list.
+ # This is the list of classes the rehearsal set will contain.
+ # The img_per_cls number is then updated to account for the classes in the experience that have more instances than the optimal,
+ # perfectly divided finetune set would allot them (see the worked example below).
+
+ # For each class in not_full_classes, the memory data of that class is sampled without replacement. If there are not enough
+ # instances in the memory for that class, the number of missing instances is accumulated in nr_remaining.
+ # At the end, the memory instances that have not been selected yet are used to sample nr_remaining indices at random.
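+ # Worked example (assumed numbers): with nr_indices_to_return = 80, an experience of 20
+ # images and 10 seen classes, img_per_cls = int((80 + 20) / 10) = 10. If one experience
+ # class already has 30 images (>= 10), it is dropped from not_full_classes and
+ # img_per_cls is recomputed as int((80 + 20 - 30) / 9) = 7 for the remaining classes;
+ # slots that a rare class cannot fill are counted in nr_remaining and later drawn
+ # uniformly from the not-yet-selected memory indices.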
|
|
|
+ def __init__(self, buffer_data_ratio, nr_classes):
|
|
|
+ super(FillExpBasedRehearsalPlugin, self).__init__(buffer_data_ratio)
|
|
|
+ self.nr_classes = nr_classes
|
|
|
+
|
|
|
+ def _get_experience_based_indicies(self, strategy, **kwargs):
|
|
|
+
|
|
|
+ nr_indices_to_return = int(len(strategy.experience.dataset)*self.buffer_data_ratio)
|
|
|
+ if len(self.cumulative_dataset)<= nr_indices_to_return:
|
|
|
+ return None
|
|
|
+
|
|
|
+ targets_in_exp = torch.tensor(strategy.experience.dataset.targets)
|
|
|
+ unique, counts = torch.unique(targets_in_exp, return_counts=True)
|
|
|
+ idxs = torch.tensor([])
|
|
|
+ seen_cls = []
|
|
|
+ for cls_i, c_idxs in self.cls_idx_dict.items():
|
|
|
+ if (c_idxs.shape[0] == 0 and cls_i in unique) or c_idxs.shape[0] > 0 :
|
|
|
+ seen_cls.append(cls_i)
|
|
|
+ exp_counts = {}
|
|
|
+ for i, u in enumerate(unique):
|
|
|
+ exp_counts[u.item()]= counts[i]
|
|
|
+
|
|
|
+ img_per_cls =int((nr_indices_to_return+targets_in_exp.shape[0])/len(seen_cls))
|
|
|
+ counts_over = counts[counts>=img_per_cls]
|
|
|
+ counts_over_total = torch.sum(counts_over)
|
|
|
+
|
|
|
+ nr_not_full_classes =len(seen_cls)- torch.sum(counts>=img_per_cls)
|
|
|
+ seen_cls = torch.tensor(seen_cls)
|
|
|
+ not_full_classes = seen_cls[torch.logical_not(torch.isin(seen_cls, unique[counts>=img_per_cls]))]
|
|
|
+
|
|
|
+ img_per_cls =int( img_per_cls-(counts_over_total/nr_not_full_classes))
|
|
|
+ img_per_cls =int(torch.floor((nr_indices_to_return+targets_in_exp.shape[0]-counts_over_total)/len(not_full_classes)))
|
|
|
+
|
|
|
+ nr_remaining = 0
|
|
|
+ for cls_i, c_idxs in self.cls_idx_dict.items():
|
|
|
+
|
|
|
+ if cls_i in not_full_classes:
|
|
|
+ if cls_i in unique: # class in experience
|
|
|
+
|
|
|
+ count_cls = exp_counts[cls_i]
|
|
|
+
|
|
|
+ img_left_to_draw_for_cls_i =img_per_cls-count_cls
|
|
|
+ else: # one of the already seen classes not in current experience
|
|
|
+ img_left_to_draw_for_cls_i = img_per_cls
|
|
|
+
|
|
|
+ if img_left_to_draw_for_cls_i <=0: # occurs when the updated img_per_cls is smaller than the number of instances some classes already have in the experience
|
|
|
+ cls__i_idxs = torch.tensor([])
|
|
|
+ elif c_idxs.shape[0] > img_left_to_draw_for_cls_i:
|
|
|
+ indices = torch.randperm(c_idxs.shape[0])[:img_left_to_draw_for_cls_i]
|
|
|
+ cls__i_idxs = c_idxs[indices]
|
|
|
+ elif c_idxs.shape[0]>0 and c_idxs.shape[0]<= img_left_to_draw_for_cls_i:
|
|
|
+ cls__i_idxs = c_idxs
|
|
|
+ nr_remaining += (img_left_to_draw_for_cls_i-c_idxs.shape[0])
|
|
|
+ else: # occurs when there is a new class that has not been added to the memory yet
|
|
|
+ cls__i_idxs = torch.tensor([])
|
|
|
+ nr_remaining += img_left_to_draw_for_cls_i
|
|
|
+
|
|
|
+ idxs = torch.concat([idxs, cls__i_idxs])
|
|
|
+
|
|
|
+ idxs = torch.tensor(idxs, dtype=int)
|
|
|
+ nr_idxs_to_draw_randomly = nr_indices_to_return-idxs.shape[0]
|
|
|
+ if nr_idxs_to_draw_randomly >0:
|
|
|
+
|
|
|
+ cumulative_dataset_idxs = torch.arange(0, len(self.cumulative_dataset))
|
|
|
+ cumulative_dataset_idxs_remaining = cumulative_dataset_idxs[torch.logical_not(torch.isin(cumulative_dataset_idxs, idxs))]
|
|
|
+ pos_picked = torch.randperm(len(cumulative_dataset_idxs_remaining))[:nr_idxs_to_draw_randomly]
|
|
|
+ randomly_picked_idxs = cumulative_dataset_idxs_remaining[pos_picked]
|
|
|
+ idxs = torch.concat([idxs, randomly_picked_idxs])
|
|
|
+ idxs = torch.tensor(idxs, dtype=int)
|
|
|
+ if idxs.shape[0]!= nr_indices_to_return:
|
|
|
+ print('ATTENTION: idxs does not have the correct length')
|
|
|
+ print(len(idxs))
|
|
|
+ print(nr_indices_to_return)
|
|
|
+ return idxs
|
|
|
+
|
|
|
+ def _get_indices(self, strategy):
|
|
|
+ return self._get_experience_based_indicies(strategy)
|
|
|
+
|
|
|
+
|
|
|
+class FillExpOversampleBasedRehearsalPlugin(ClassImbalanceRehersalPlugin):
|
|
|
+ # With this rehearsal strategy the number of images in the finetune set (experience data plus selected memory data, i.e. the rehearsal set)
+ # should be roughly the same for each class. This can only hold approximately, because sometimes a class in the
+ # current experience is larger than the per-class share it would get if the finetune set were divided equally.
+ #
+ # First, the number of images each class should contribute to a perfectly balanced finetune set is calculated (img_per_cls),
+ # based on the unique classes in the memory (cls_idx_dict) and the unique classes in the current experience.
+ # All unique classes in memory and current experience are unified in the seen_cls list.
+ # The classes from that list that are 'full', i.e. have more images than img_per_cls, are removed to form the not_full_classes list.
+ # This is the list of classes the rehearsal set will contain.
+ # The img_per_cls number is then updated to account for the classes in the experience that have more instances than the optimal,
+ # perfectly divided finetune set would allot them.
+ # In FillExpOversample the classes are filled by sampling or oversampling
+ # from the memory (cls_idx_dict). The memory indices used for the rehearsal set are collected in the idxs tensor.
+ # When the current experience contains a new class that has not occurred previously in the stream, the experience itself
+ # is oversampled to ensure the finetune set has the best possible class balance. This is implemented with the exp_oversample_idxs tensor
+ # (see the sketch below).
+
+ # Both tensors are returned, and the train_dataset_adaptation function in the parent class ClassImbalanceRehersalPlugin handles them
+ # to build the rehearsal set.
|
|
|
+
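+ # Illustrative sketch (hypothetical class): if class 7 appears for the first time with only
+ # 4 images in the experience while img_per_cls = 10, the 6 missing slots are filled by
+ # oversampling the experience indices of class 7 with replacement, e.g.
+ #     pos = torch.multinomial(torch.ones(4), 6, replacement=True)
+ # and the chosen duplicates go into exp_oversample_idxs rather than into idxs.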
|
|
|
+ def __init__(self, buffer_data_ratio, nr_classes):
|
|
|
+ super(FillExpOversampleBasedRehearsalPlugin, self).__init__(buffer_data_ratio)
|
|
|
+ self.nr_classes = nr_classes
|
|
|
+
|
|
|
+ def _get_experience_based_indicies(self, strategy, **kwargs):
|
|
|
+
|
|
|
+ nr_indices_to_return = int(len(strategy.experience.dataset)*self.buffer_data_ratio)
|
|
|
+ if len(self.cumulative_dataset)<= nr_indices_to_return:
|
|
|
+ return None
|
|
|
+
|
|
|
+ targets_in_exp = torch.tensor(strategy.experience.dataset.targets)
|
|
|
+ unique, counts = torch.unique(targets_in_exp, return_counts=True)
|
|
|
+ idxs = torch.tensor([])
|
|
|
+ exp_oversample_idxs = torch.tensor([], )
|
|
|
+ seen_cls = []
|
|
|
+ for cls_i, c_idxs in self.cls_idx_dict.items():
|
|
|
+ if (c_idxs.shape[0] == 0 and cls_i in unique) or c_idxs.shape[0] > 0 :
|
|
|
+ seen_cls.append(cls_i)
|
|
|
+ exp_counts = {}
|
|
|
+ for i, u in enumerate(unique):
|
|
|
+ exp_counts[u.item()]= counts[i]
|
|
|
+
|
|
|
+ img_per_cls =int((nr_indices_to_return+targets_in_exp.shape[0])/len(seen_cls))
|
|
|
+ counts_over = counts[counts>=img_per_cls]
|
|
|
+ counts_over_total = torch.sum(counts_over)
|
|
|
+
|
|
|
+ nr_not_full_classes =len(seen_cls)- torch.sum(counts>=img_per_cls)
|
|
|
+ seen_cls = torch.tensor(seen_cls)
|
|
|
+ not_full_classes = seen_cls[torch.logical_not(torch.isin(seen_cls, unique[counts>=img_per_cls]))]
|
|
|
+
|
|
|
+ img_per_cls =int( img_per_cls-(counts_over_total/nr_not_full_classes))
|
|
|
+ img_per_cls =int(torch.ceil((nr_indices_to_return+targets_in_exp.shape[0]-counts_over_total)/len(not_full_classes)))
|
|
|
+
|
|
|
+ nr_remaining = 0
|
|
|
+ for cls_i, c_idxs in self.cls_idx_dict.items():
|
|
|
+
|
|
|
+ if cls_i in not_full_classes:
|
|
|
+ if cls_i in unique: # if class is in experience
|
|
|
+
|
|
|
+ count_cls = exp_counts[cls_i]
|
|
|
+
|
|
|
+ img_left_to_draw_for_cls_i =img_per_cls-count_cls
|
|
|
+ else:
|
|
|
+ img_left_to_draw_for_cls_i = img_per_cls
|
|
|
+
|
|
|
+
|
|
|
+ if img_left_to_draw_for_cls_i<=0: # occurs when the updated img_per_cls is smaller than the number of instances some classes already have in the experience
|
|
|
+ cls__i_idxs=torch.tensor([])
|
|
|
+ elif c_idxs.shape[0] > img_left_to_draw_for_cls_i:
|
|
|
+ indices = torch.randperm(c_idxs.shape[0])[:img_left_to_draw_for_cls_i] # randomly select memory images from the class
|
|
|
+ cls__i_idxs = c_idxs[indices]
|
|
|
+ elif c_idxs.shape[0]>0 and c_idxs.shape[0]<= img_left_to_draw_for_cls_i: #
|
|
|
+ c_idxs_pos_idx = torch.multinomial(torch.ones((c_idxs.shape[0], )), img_left_to_draw_for_cls_i, replacement=True) # draw with replacement to fill the number of idxs
|
|
|
+ cls__i_idxs = c_idxs[c_idxs_pos_idx]
|
|
|
+
|
|
|
+ else: # this case applies when the class appears for the first time, meaning there are no images in the memory yet
|
|
|
+
|
|
|
+ cls_i_exp_idxs = torch.where(targets_in_exp==cls_i)[0]
|
|
|
+ cls_i_exp_pos_idx = torch.multinomial(torch.ones((cls_i_exp_idxs.shape[0], )), img_left_to_draw_for_cls_i, replacement=True)
|
|
|
+ exp_oversample_idxs = torch.concat([exp_oversample_idxs, cls_i_exp_idxs[cls_i_exp_pos_idx]])
|
|
|
+ cls__i_idxs = torch.tensor([])
|
|
|
+
|
|
|
+
|
|
|
+ idxs = torch.concat([idxs, cls__i_idxs])
|
|
|
+
|
|
|
+ idxs = torch.tensor(idxs, dtype=int)
|
|
|
+ exp_oversample_idxs = torch.tensor(exp_oversample_idxs, dtype=int)
|
|
|
+
|
|
|
+ idxs_selected_for_return = idxs.shape[0]+exp_oversample_idxs.shape[0]
|
|
|
+ nr_idxs_to_draw_randomly = nr_indices_to_return-idxs_selected_for_return
|
|
|
+
|
|
|
+ # handling edge cases where the total number of indices returned is larger than the exact number of indices needed in the rehearsal set.
+ # This occurs due to rounding errors.
|
|
|
+ if idxs_selected_for_return> nr_indices_to_return:
|
|
|
+
|
|
|
+ diff = idxs_selected_for_return-nr_indices_to_return
|
|
|
+ if len(idxs)>len(exp_oversample_idxs):
|
|
|
+ idxs_drop_set = torch.randperm(idxs.shape[0])[:(idxs.shape[0]-diff)]
|
|
|
+ idxs= idxs[idxs_drop_set]
|
|
|
+ else:
|
|
|
+ idxs_drop_set = torch.randperm(exp_oversample_idxs.shape[0])[:(exp_oversample_idxs.shape[0]-diff)]
|
|
|
+ exp_oversample_idxs = exp_oversample_idxs[idxs_drop_set]
|
|
|
+
|
|
|
+ idxs_selected_for_return = idxs.shape[0]+exp_oversample_idxs.shape[0]
|
|
|
+
|
|
|
+
|
|
|
+ return [idxs, exp_oversample_idxs]
|
|
|
+
|
|
|
+ def _get_indices(self, strategy):
|
|
|
+ return self._get_experience_based_indicies(strategy)
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+class RandomRehersal(ClassImbalanceRehersalPlugin):
|
|
|
+ # This rehearsal strategy simply draws uniformly from the memory for rehearsal.
+ # No class-specific information is used.
|
|
|
+
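+ # Illustrative sketch (assumed sizes): for a memory of 1000 images and
+ # nr_indices_to_return = 200, this reduces to random.sample(range(1000), 200),
+ # with every memory instance equally likely to be picked.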
|
|
|
+ def __init__(self, buffer_data_ratio):
|
|
|
+ super(RandomRehersal, self).__init__(buffer_data_ratio)
|
|
|
+
|
|
|
+ def _get_class_random_indicies(self, strategy):
|
|
|
+ nr_indices_to_return = int(len(strategy.experience.dataset)*self.buffer_data_ratio)
|
|
|
+ if len(self.cumulative_dataset)<= nr_indices_to_return:
|
|
|
+ return None
|
|
|
+ else:
|
|
|
+ idxs = random.sample(range(len(self.cumulative_dataset)), nr_indices_to_return)
|
|
|
+ idxs = torch.tensor(idxs)
|
|
|
+
|
|
|
+ return idxs
|
|
|
+
|
|
|
+ def _get_indices(self, strategy):
|
|
|
+ return self._get_class_random_indicies(strategy)
|
|
|
+
|
|
|
+
|
|
|
+class MaximallyInterferedRetrievalRehersalPlugin(ClassImbalanceRehersalPlugin):
|
|
|
+
|
|
|
+ # This is a rehearsal strategy based on the paper 'Online Continual Learning with Maximally Interfered Retrieval' by Aljundi et al.
+ # Their idea is to select the images from the memory whose loss increases most if the model is finetuned only on the current experience.
+ # A randomly selected subset of the memory is evaluated with the current model, then a virtual update step is performed by finetuning on the
+ # current experience, and the updated model is used to evaluate the memory subset again. The instances whose loss increased the
+ # most are selected for the rehearsal set. The rehearsal set and the current experience together are then used to finetune the model
+ # (without the virtual update) and perform the actual training step.
+
+ # This method was not specifically designed to handle class-imbalanced data. However, it is one of the few methods that approach rehearsal learning
+ # by strategically selecting images from the memory, which is why it is used as a comparison for our methods.
|
|
|
+
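+ # Illustrative sketch (assumed loss values): if a candidate's loss on the memory subset is
+ # 0.4 before the virtual update and 1.1 afterwards, its interference score is
+ # 1.1 - 0.4 = 0.7; candidates are ranked by this score in descending order and the top
+ # nr_indices_to_return of them form the rehearsal set.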
|
|
|
+
|
|
|
+ def __init__(self, buffer_data_ratio, candidate_subset_size=10):
|
|
|
+ self.eval_criterion = LabelSmoothingCrossEntropy(reduction='none')
|
|
|
+ self.criterion = LabelSmoothingCrossEntropy()
|
|
|
+ self.candidate_subset_size = candidate_subset_size
|
|
|
+ super(MaximallyInterferedRetrievalRehersalPlugin, self).__init__(buffer_data_ratio=buffer_data_ratio)
|
|
|
+
|
|
|
+
|
|
|
+ def _unpack_minibatch_MIR(self, mbatch, strategy):
|
|
|
+
|
|
|
+ paths = mbatch[1][0]
|
|
|
+ mbatch = mbatch[0]
|
|
|
+
|
|
|
+ assert len(mbatch) >= 3
|
|
|
+
|
|
|
+ input_imgs = mbatch[0].to(strategy.device)
|
|
|
+ targets = mbatch[1].to(strategy.device)
|
|
|
+ return paths, input_imgs, targets
|
|
|
+
|
|
|
+ def _get_MIR_based_indicies(self, strategy):
|
|
|
+ logging_mir_dict ={}
|
|
|
+ nr_indices_to_return = int(len(strategy.experience.dataset)*self.buffer_data_ratio)
|
|
|
+ if len(self.cumulative_dataset)<= nr_indices_to_return:
|
|
|
+ return None
|
|
|
+ candidate_subset_idxs = sorted(random.sample(range(len(self.cumulative_dataset)), min(nr_indices_to_return,len(self.cumulative_dataset))))
|
|
|
+ candidate_subset = AvalancheSubset(self.cumulative_dataset, candidate_subset_idxs)
|
|
|
+ original_model_weights = copy.deepcopy(strategy.model.state_dict())
|
|
|
+ logging_mir_dict['candidate_subset_idxs'] = candidate_subset_idxs
|
|
|
+
|
|
|
+ eps=1e-4
|
|
|
+ weight_decay=75e-3
|
|
|
+ self.optimizer = AdamW(strategy.model.parameters(), eps=eps, weight_decay=weight_decay)
|
|
|
+ data_loader = TaskBalancedDataLoader(
|
|
|
+ candidate_subset,
|
|
|
+ collate_mbatches=_seq_collate_mbatches_fn,
|
|
|
+ batch_size=strategy.train_mb_size,
|
|
|
+ shuffle=False,
|
|
|
+ pin_memory=False)
|
|
|
+ strategy.model.eval()
|
|
|
+ all_losses_pre_update = []
|
|
|
+ all_paths = []
|
|
|
+ with torch.no_grad():
|
|
|
+ for mbatch in data_loader:
|
|
|
+
|
|
|
+ paths, input_imgs, targets = self._unpack_minibatch_MIR(mbatch, strategy)
|
|
|
+
|
|
|
+ all_paths = all_paths+ np.array(paths).tolist()
|
|
|
+ mb_output = strategy.model(input_imgs)
|
|
|
+ pre_update_loss = self.eval_criterion(mb_output, targets)
|
|
|
+ all_losses_pre_update = all_losses_pre_update+ pre_update_loss.cpu().detach().numpy().tolist()
|
|
|
+
|
|
|
+ logging_mir_dict['all_losses_pre_update'] = all_losses_pre_update
|
|
|
+
|
|
|
+ # virtual update step
|
|
|
+ experience_data_loader = TaskBalancedDataLoader(
|
|
|
+ strategy.experience.dataset,
|
|
|
+ collate_mbatches=_seq_collate_mbatches_fn,
|
|
|
+ batch_size=strategy.train_mb_size,
|
|
|
+ shuffle=True,
|
|
|
+ pin_memory=False)
|
|
|
+
|
|
|
+ strategy.model.train()
|
|
|
+ reset_optimizer(self.optimizer, strategy.model)
|
|
|
+
|
|
|
+ for mbatch in experience_data_loader:
|
|
|
+ _, input_imgs, targets = self._unpack_minibatch_MIR(mbatch, strategy)
|
|
|
+ self.optimizer.zero_grad()
|
|
|
+ mb_output = strategy.model(input_imgs)
|
|
|
+ loss = self.criterion(mb_output, targets)
|
|
|
+ loss.backward()
|
|
|
+ self.optimizer.step()
|
|
|
+
|
|
|
+ data_loader = TaskBalancedDataLoader(
|
|
|
+ candidate_subset,
|
|
|
+ collate_mbatches=_seq_collate_mbatches_fn,
|
|
|
+ batch_size=strategy.train_mb_size,
|
|
|
+ shuffle=False,
|
|
|
+ pin_memory=False)
|
|
|
+ strategy.model.eval()
|
|
|
+ all_losses_after_update = []
|
|
|
+ with torch.no_grad():
|
|
|
+ for mbatch in data_loader:
|
|
|
+
|
|
|
+ paths, input_imgs, targets = self._unpack_minibatch_MIR(mbatch, strategy)
|
|
|
+ mb_output = strategy.model(input_imgs)
|
|
|
+
|
|
|
+ after_update_loss = self.eval_criterion(mb_output, targets)
|
|
|
+
|
|
|
+ all_losses_after_update = all_losses_after_update + after_update_loss.cpu().detach().numpy().tolist()
|
|
|
+ logging_mir_dict['all_losses_after_update'] = all_losses_after_update
|
|
|
+
|
|
|
+ differences = (np.array(all_losses_after_update) - np.array(all_losses_pre_update)).tolist() # loss increase on the memory subset caused by the virtual update
|
|
|
+
|
|
|
+ from_candidates_idxs =sorted(range(len(differences)), key=differences.__getitem__, reverse=True)
|
|
|
+
|
|
|
+ idx_ordering = np.argsort(-np.array(differences))
|
|
|
+ reordered_paths = np.array(all_paths)[idx_ordering]
|
|
|
+
|
|
|
+ strategy.mir_losses_dict[strategy.experience.current_experience] = logging_mir_dict
|
|
|
+ candidate_subset_idxs = np.array(candidate_subset_idxs)
|
|
|
+ idxs = candidate_subset_idxs[from_candidates_idxs][:nr_indices_to_return]
|
|
|
+ strategy.model.load_state_dict(original_model_weights)
|
|
|
+
|
|
|
+ return idxs
|
|
|
+
|
|
|
+
|
|
|
+ def _get_indices(self, strategy):
|
|
|
+ return self._get_MIR_based_indicies(strategy)
|