import warnings
import datetime
import pickle as pkl
from copy import copy
from collections import defaultdict
from typing import Union, Sequence, TYPE_CHECKING

import matplotlib.pyplot as plt
import numpy as np
import torch

from avalanche.evaluation.metrics import accuracy_metrics, loss_metrics
from avalanche.training.plugins.strategy_plugin import StrategyPlugin
from avalanche.logging import StrategyLogger, InteractiveLogger, CSVLogger, \
    GenericCSVLogger

if TYPE_CHECKING:
    from avalanche.evaluation import PluginMetric
    from avalanche.training.strategies import BaseStrategy


class EvaluationPlugin(StrategyPlugin):
    """
    An evaluation plugin that obtains relevant data from the
    training and eval loops of the strategy through callbacks.

    The plugin keeps a dictionary with the last recorded value of each
    metric; this dictionary is returned by the `train` and `eval` methods
    of the strategies.

    It is also possible to keep a dictionary with all recorded metrics by
    specifying `collect_all=True`. That dictionary can be retrieved via
    the `get_all_metrics` method.

    This plugin also logs metric values through the provided loggers.
    """
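
    # A minimal construction sketch (illustrative usage, not part of the
    # original module; the metric helpers and logger are the ones imported
    # above, see also the end-to-end sketch at the bottom of this file):
    #
    #     eval_plugin = EvaluationPlugin(
    #         accuracy_metrics(epoch=True, experience=True, stream=True),
    #         loss_metrics(epoch=True, experience=True, stream=True),
    #         loggers=[InteractiveLogger()],
    #         collect_all=True)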

    def __init__(self,
                 *metrics: Union['PluginMetric', Sequence['PluginMetric']],
                 loggers: Union['StrategyLogger',
                                Sequence['StrategyLogger']] = None,
                 collect_all=True,
                 benchmark=None,
                 strict_checks=False,
                 suppress_warnings=False):
        """
        Creates an instance of the evaluation plugin.

        :param metrics: The metrics to compute.
        :param loggers: The loggers to be used to log the metric values.
        :param collect_all: if True, collect in a separate dictionary all
            metric curve values. This dictionary is accessible through the
            `get_all_metrics` method.
        :param benchmark: continual learning benchmark needed to check stream
            completeness during evaluation and other properties. If None,
            no check will be conducted and the plugin will emit a warning
            to signal this fact.
        :param strict_checks: if True, `benchmark` has to be provided.
            In this case, only complete evaluation streams are admitted
            when calling `eval`; an error is raised otherwise. When False,
            `benchmark` can be None and only warnings are raised.
        :param suppress_warnings: if True, warnings and errors will never be
            raised from the plugin. If False, warnings and errors will be
            raised following the `benchmark` and `strict_checks` behavior.
        """
        super().__init__()
        self.collect_all = collect_all
        self.benchmark = benchmark
        self.strict_checks = strict_checks
        self.suppress_warnings = suppress_warnings

        # Flatten the (possibly nested) metric list, keeping any time-based
        # metric last in the list.
        flat_metrics_list = []
        time_metric = None
        for metric in metrics:
            if isinstance(metric, Sequence):
                for m in metric:
                    if 'Time' in str(m):
                        time_metric = m
                    else:
                        flat_metrics_list.append(m)
            else:
                if 'Time' in str(metric):
                    time_metric = metric
                else:
                    flat_metrics_list.append(metric)
        if time_metric is not None:
            flat_metrics_list.append(time_metric)
        self.metrics = flat_metrics_list

        if loggers is None:
            loggers = []
        elif not isinstance(loggers, Sequence):
            loggers = [loggers]

        if benchmark is None:
            if not suppress_warnings:
                if strict_checks:
                    raise ValueError("Benchmark cannot be None "
                                     "in strict mode.")
                else:
                    warnings.warn(
                        "No benchmark provided to the evaluation plugin. "
                        "Metrics may be computed on an inconsistent portion "
                        "of the streams, use at your own risk.")
        else:
            self.complete_test_stream = benchmark.test_stream

        self.loggers: Sequence['StrategyLogger'] = loggers
        if len(self.loggers) == 0:
            warnings.warn('No loggers specified, metrics will not be logged')

        # Give CSV-style loggers access to the full metric list.
        for logger in loggers:
            if isinstance(logger, (CSVLogger, GenericCSVLogger)):
                logger.all_metrics = self.metrics

        if self.collect_all:
            # For each metric curve, collect all emitted values. The
            # dictionary key is the full metric name; the value is a tuple
            # of two lists: the first gathers the x values (indices of the
            # time steps at which the corresponding metric value has been
            # emitted), the second gathers the metric values.
            self.all_metric_results = defaultdict(lambda: ([], []))
        # Dictionary of the last values emitted. The key is the full metric
        # name, the value is the metric value.
        self.last_metric_results = {}

        self._active = True
        """If False, no metrics will be collected."""

    @property
    def active(self):
        return self._active

    @active.setter
    def active(self, value):
        assert value is True or value is False, \
            "Active must be set as either True or False"
        self._active = value
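
    # Usage sketch (illustrative, not part of the original code): collection
    # can be paused around phases that should not be tracked.
    #
    #     plugin.active = False   # all metric callbacks become no-ops
    #     ...                     # e.g. a warm-up or debugging pass
    #     plugin.active = True    # resume metric collection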

    def _update_metrics(self, strategy: 'BaseStrategy', callback: str):
        if not self._active:
            return []

        # Ask every metric for its value at this callback and flatten the
        # results into a single list of emitted metric values.
        metric_values = []
        for metric in self.metrics:
            metric_result = getattr(metric, callback)(strategy)
            if isinstance(metric_result, Sequence):
                metric_values += list(metric_result)
            elif metric_result is not None:
                metric_values.append(metric_result)

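        # Each emitted value is expected to expose `name` (the full metric
        # name, e.g. something like 'Top1_Acc_Stream/eval_phase/test_stream';
        # the exact format depends on the metric), `x_plot` (an integer step
        # index) and `value` (the metric value itself).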
        for metric_value in metric_values:
            name = metric_value.name
            x = metric_value.x_plot
            val = metric_value.value
            if self.collect_all:
                self.all_metric_results[name][0].append(x)
                self.all_metric_results[name][1].append(val)
            self.last_metric_results[name] = val

        # Forward the same callback, together with the emitted values, to
        # every logger.
        for logger in self.loggers:
            getattr(logger, callback)(strategy, metric_values)

        return metric_values

    def get_last_metrics(self):
        """
        Return a shallow copy of the dictionary holding, for each metric
        name, the last recorded value.

        :return: a dictionary with full metric names as keys and the last
            metric values as values.
        """
        return copy(self.last_metric_results)

    def get_all_metrics(self):
        """
        Return the dictionary of all collected metrics.
        This method should be called only when `collect_all` is True.

        :return: if `collect_all` is True, a dictionary with full metric
            names as keys and a tuple of two lists as values. The first
            list gathers the x values (indices of the time steps at which
            the corresponding metric value has been emitted), the second
            list gathers the metric values. If `collect_all` is False,
            an empty dictionary is returned.
        """
        if self.collect_all:
            return self.all_metric_results
        else:
            return {}
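
    # Illustrative shape of the dictionary returned by `get_all_metrics`
    # (metric name and numbers are made up):
    #
    #     {'Top1_Acc_Stream/eval_phase/test_stream': ([10, 20, 30],
    #                                                 [0.51, 0.63, 0.70])}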

    def reset_last_metrics(self):
        """
        Empty the dictionary storing the last value recorded for each metric.
        """
        self.last_metric_results = {}

    def before_training(self, strategy: 'BaseStrategy', **kwargs):
        self._update_metrics(strategy, 'before_training')

    def before_training_exp(self, strategy: 'BaseStrategy', **kwargs):
        self._update_metrics(strategy, 'before_training_exp')

    def before_train_dataset_adaptation(self, strategy: 'BaseStrategy',
                                        **kwargs):
        self._update_metrics(strategy, 'before_train_dataset_adaptation')

    def after_train_dataset_adaptation(self, strategy: 'BaseStrategy',
                                       **kwargs):
        self._update_metrics(strategy, 'after_train_dataset_adaptation')

    def before_training_epoch(self, strategy: 'BaseStrategy', **kwargs):
        self._update_metrics(strategy, 'before_training_epoch')

    def before_training_iteration(self, strategy: 'BaseStrategy', **kwargs):
        self._update_metrics(strategy, 'before_training_iteration')

    def before_forward(self, strategy: 'BaseStrategy', **kwargs):
        self._update_metrics(strategy, 'before_forward')

    def after_forward(self, strategy: 'BaseStrategy', **kwargs):
        self._update_metrics(strategy, 'after_forward')

    def before_backward(self, strategy: 'BaseStrategy', **kwargs):
        self._update_metrics(strategy, 'before_backward')

    def after_backward(self, strategy: 'BaseStrategy', **kwargs):
        self._update_metrics(strategy, 'after_backward')

    def after_training_iteration(self, strategy: 'BaseStrategy', **kwargs):
        self._update_metrics(strategy, 'after_training_iteration')

    def before_update(self, strategy: 'BaseStrategy', **kwargs):
        self._update_metrics(strategy, 'before_update')

    def after_update(self, strategy: 'BaseStrategy', **kwargs):
        self._update_metrics(strategy, 'after_update')

    def after_training_epoch(self, strategy: 'BaseStrategy', **kwargs):
        self._update_metrics(strategy, 'after_training_epoch')

    def after_training_exp(self, strategy: 'BaseStrategy', **kwargs):
        self._update_metrics(strategy, 'after_training_exp')

    def after_training(self, strategy: 'BaseStrategy', **kwargs):
        self._update_metrics(strategy, 'after_training')

    def before_eval(self, strategy: 'BaseStrategy', **kwargs):
        self._update_metrics(strategy, 'before_eval')
        msgw = "Evaluation stream is not equal to the complete test stream. " \
               "This may result in inconsistent metrics. Use at your own risk."
        msge = "Stream provided to `eval` must be the same as the complete " \
               "evaluation stream."
        # If a benchmark was provided, check that the stream passed to
        # `eval` matches the complete test stream, experience by experience.
        if self.benchmark is not None:
            for i, exp in enumerate(self.complete_test_stream):
                try:
                    current_exp = strategy.current_eval_stream[i]
                    if exp.current_experience != \
                            current_exp.current_experience:
                        if not self.suppress_warnings:
                            if self.strict_checks:
                                raise ValueError(msge)
                            else:
                                warnings.warn(msgw)
                except IndexError:
                    # The provided stream is shorter than the complete one.
                    if self.strict_checks:
                        raise ValueError(msge)
                    else:
                        warnings.warn(msgw)
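
    # Behaviour sketch (illustrative): with a benchmark and
    # `strict_checks=True`, evaluating on a partial stream such as
    # `strategy.eval(benchmark.test_stream[:1])` raises the ValueError
    # above; with `strict_checks=False` only the warning is emitted.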

    def before_eval_dataset_adaptation(self, strategy: 'BaseStrategy',
                                       **kwargs):
        self._update_metrics(strategy, 'before_eval_dataset_adaptation')

    def after_eval_dataset_adaptation(self, strategy: 'BaseStrategy',
                                      **kwargs):
        self._update_metrics(strategy, 'after_eval_dataset_adaptation')

    def before_eval_exp(self, strategy: 'BaseStrategy', **kwargs):
        self._update_metrics(strategy, 'before_eval_exp')

    def after_eval_exp(self, strategy: 'BaseStrategy', **kwargs):
        self._update_metrics(strategy, 'after_eval_exp')

    def after_eval(self, strategy: 'BaseStrategy', **kwargs):
        self._update_metrics(strategy, 'after_eval')
        if strategy.last_eval:
            # Build a confusion matrix over all predictions collected by the
            # strategy and dump both the plot and the raw data to disk.
            plt.figure(figsize=(27, 18))
            tick_marks = np.arange(strategy.model.num_classes)
            confusion_data_dict = {'tick_marks': tick_marks}
            title = 'Acc: ' + str(np.round(strategy.top1_acc * 100, 2)) \
                + ' ClsAvgAcc: ' \
                + str(np.round(strategy.cls_top1_avg_acc * 100, 2))
            title = title + '\n' + 'SeqAcc: ' \
                + str(np.round(strategy.seq_top1_acc * 100, 2)) \
                + ' SeqClsAcc: ' \
                + str(np.round(strategy.seq_cls_top1_avg_acc * 100, 2))
            title = title + ' SeqAnyAcc: ' \
                + str(np.round(strategy.seq_any_acc * 100, 2))
            cmt = np.zeros((strategy.model.num_classes,
                            strategy.model.num_classes))
            strategy.all_preds = strategy.all_preds.cpu()
            strategy.all_targets = strategy.all_targets.cpu()

            # Pick the per-class accuracy dictionary and output file name
            # matching the evaluated stream.
            if strategy.experience.origin_stream.name == 'validation':
                file_name = '_after_exp' \
                    + str(strategy.current_train_exp_seen) + '_validation'
                class_acc_dict = strategy.val_cls_acc_dict
            elif strategy.experience.origin_stream.name == 'test':
                file_name = '_after_exp' \
                    + str(strategy.current_train_exp_seen) + '_test'
                class_acc_dict = strategy.test_cls_acc_dict
            elif strategy.experience.origin_stream.name == 'train':
                file_name = '_after_exp' \
                    + str(strategy.current_train_exp_seen) + '_train'
                class_acc_dict = strategy.train_cls_acc_dict

            confusion_data_dict['title'] = title
            confusion_data_dict['top1_acc'] = strategy.top1_acc
            confusion_data_dict['cls_top1_avg_acc'] = \
                strategy.cls_top1_avg_acc
            confusion_data_dict['seq_top1_acc'] = strategy.seq_top1_acc
            confusion_data_dict['seq_cls_top1_avg_acc'] = \
                strategy.seq_cls_top1_avg_acc
            confusion_data_dict['seq_any_acc'] = strategy.seq_any_acc
            confusion_data_dict['all_preds'] = strategy.all_preds
            confusion_data_dict['all_targets'] = strategy.all_targets
            confusion_data_dict['label_dict'] = strategy.label_dict
            confusion_data_dict['cls_acc_dict'] = class_acc_dict
            confusion_data_dict['unique_train_cls_dict'] = \
                strategy.unique_train_cls_dict
            confusion_data_dict['total_train_cls_dict'] = \
                strategy.total_train_cls_dict
            confusion_data_dict['total_test_cls_dict'] = \
                strategy.total_test_cls_dict
            confusion_data_dict['total_validation_cls_dict'] = \
                strategy.total_validation_cls_dict
            confusion_data_dict['cumulative_dataset_paths'] = \
                strategy.cumulative_dataset_paths
            # Indices correspond to the cumulative dataset paths above.
            confusion_data_dict['rehearsal_indicies_picked'] = \
                strategy.rehearsal_indicies_picked
            confusion_data_dict['mir_losses_dict'] = strategy.mir_losses_dict
            confusion_data_dict['memory_dataset_paths'] = \
                strategy.memory_dataset_paths
            confusion_data_dict['memory_dataset_targets'] = \
                strategy.memory_dataset_targets

            # Accumulate counts, then row-normalize so each row shows the
            # distribution of predictions for a true class.
            for pl, tl in zip(strategy.all_preds, strategy.all_targets):
                cmt[int(tl), int(pl)] = cmt[int(tl), int(pl)] + 1
            cmt = cmt.astype('float') / cmt.sum(axis=1)[:, np.newaxis]

            classes = []
            class_labels = []
            for i in range(strategy.model.num_classes):
                cls_acc = -1
                if isinstance(class_acc_dict[i].result(), dict) and \
                        0 in class_acc_dict[i].result().keys():
                    if isinstance(class_acc_dict[i].result()[0], dict):
                        if 0 in class_acc_dict[i].result()[0].keys():
                            cls_acc = class_acc_dict[i].result()[0][0]
                    elif isinstance(class_acc_dict[i].result()[0], float):
                        cls_acc = class_acc_dict[i].result()[0]
                if i in strategy.label_dict.keys():
                    classes.append(strategy.label_dict[i])
                else:
                    classes.append('Unknown Class')
                counts = '/'.join(
                    str(self.get_int_value(d[i])) for d in (
                        strategy.unique_train_cls_dict,
                        strategy.total_train_cls_dict,
                        strategy.total_test_cls_dict,
                        strategy.total_validation_cls_dict))
                cls_label = classes[-1] + ' cls_acc: ' \
                    + str(np.round(cls_acc * 100, 2)) \
                    + '\ntrain/cumu/test/validation: ' + counts
                class_labels.append(cls_label)

            plt.imshow(cmt, interpolation='nearest', cmap=plt.cm.Blues)
            plt.title(title, fontsize=21)
            plt.colorbar()
            plt.xticks(tick_marks, classes, rotation=90, fontsize=18)
            plt.yticks(tick_marks, class_labels, fontsize=18)
            plt.ylabel('True label', fontsize=18)
            plt.xlabel('Predicted label', fontsize=18)
            print(strategy.log_dir)
            plt.tight_layout()
            datetime_stamp = str(datetime.datetime.now()) \
                .replace(' ', '_').replace(':', '-').replace('.', '-')[:-4]
            plt.savefig(str(strategy.log_dir) + datetime_stamp + file_name
                        + '_confusion.png', facecolor='white')
            plt.close()
            with open(str(strategy.log_dir) + datetime_stamp + file_name
                      + '_confusion_data_dict.pkl', 'wb') as handle:
                pkl.dump(confusion_data_dict, handle,
                         protocol=pkl.HIGHEST_PROTOCOL)

    def get_int_value(self, value):
        # Convert a scalar that may be a torch tensor into a plain number.
        if torch.is_tensor(value):
            return value.item()
        return int(value)
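
    # Offline analysis sketch (illustrative; the actual file name written in
    # `after_eval` above starts with a timestamp):
    #
    #     import pickle as pkl
    #     with open('<log_dir>/<timestamp>_after_exp0_test'
    #               '_confusion_data_dict.pkl', 'rb') as f:
    #         data = pkl.load(f)
    #     print(data['title'])
    #     print(data['cls_acc_dict'])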

    def before_eval_iteration(self, strategy: 'BaseStrategy', **kwargs):
        self._update_metrics(strategy, 'before_eval_iteration')

    def before_eval_forward(self, strategy: 'BaseStrategy', **kwargs):
        self._update_metrics(strategy, 'before_eval_forward')

    def after_eval_forward(self, strategy: 'BaseStrategy', **kwargs):
        self._update_metrics(strategy, 'after_eval_forward')

    def after_eval_iteration(self, strategy: 'BaseStrategy', **kwargs):
        self._update_metrics(strategy, 'after_eval_iteration')


# Default evaluation plugin: epoch, experience and stream level accuracy and
# loss, logged through an interactive logger.
default_logger = EvaluationPlugin(
    accuracy_metrics(minibatch=False, epoch=True, experience=True,
                     stream=True),
    loss_metrics(minibatch=False, epoch=True, experience=True, stream=True),
    loggers=[InteractiveLogger()],
    suppress_warnings=True)
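
# End-to-end usage sketch (illustrative only; `model`, `optimizer`,
# `criterion` and `benchmark` are placeholders and not defined in this
# module):
#
#     from avalanche.benchmarks.classic import SplitMNIST
#     from avalanche.training.strategies import Naive
#
#     benchmark = SplitMNIST(n_experiences=5)
#     eval_plugin = EvaluationPlugin(
#         accuracy_metrics(epoch=True, experience=True, stream=True),
#         loss_metrics(epoch=True, experience=True, stream=True),
#         loggers=[InteractiveLogger()],
#         benchmark=benchmark)
#     strategy = Naive(model, optimizer, criterion, evaluator=eval_plugin)
#     for experience in benchmark.train_stream:
#         last_metrics = strategy.train(experience)
#         strategy.eval(benchmark.test_stream)
#     all_curves = eval_plugin.get_all_metrics()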

__all__ = [
    'EvaluationPlugin',
    'default_logger'
]