# evaluation_plugin.py

import warnings
import datetime
import pickle as pkl
from copy import copy
from collections import defaultdict
from typing import Union, Sequence, TYPE_CHECKING

import matplotlib.pyplot as plt
import numpy as np
import torch

from avalanche.evaluation.metrics import accuracy_metrics, loss_metrics
from avalanche.training.plugins.strategy_plugin import StrategyPlugin
from avalanche.logging import StrategyLogger, InteractiveLogger, CSVLogger, \
    GenericCSVLogger

if TYPE_CHECKING:
    from avalanche.evaluation import PluginMetric
    from avalanche.training.strategies import BaseStrategy

class EvaluationPlugin(StrategyPlugin):
    """
    An evaluation plugin that obtains relevant data from the
    training and eval loops of the strategy through callbacks.

    The plugin keeps a dictionary with the last recorded value for each
    metric. The dictionary is returned by the `train` and `eval` methods
    of the strategies. It is also possible to keep a dictionary with all
    recorded metrics by specifying `collect_all=True`; that dictionary
    can be retrieved via the `get_all_metrics` method.

    This plugin also logs metrics using the provided loggers.
    """

    def __init__(self,
                 *metrics: Union['PluginMetric', Sequence['PluginMetric']],
                 loggers: Union['StrategyLogger',
                                Sequence['StrategyLogger']] = None,
                 collect_all=True,
                 benchmark=None,
                 strict_checks=False,
                 suppress_warnings=False):
        """
        Creates an instance of the evaluation plugin.

        :param metrics: The metrics to compute.
        :param loggers: The loggers to be used to log the metric values.
        :param collect_all: if True, collect in a separate dictionary all
            metric curve values. This dictionary is accessible with the
            `get_all_metrics` method.
        :param benchmark: continual learning benchmark needed to check
            stream completeness during evaluation, among other properties.
            If None, no check will be conducted and the plugin will emit
            a warning to signal this fact.
        :param strict_checks: if True, `benchmark` has to be provided.
            In this case, only full evaluation streams are admitted when
            calling `eval`. An error will be raised otherwise. When False,
            `benchmark` can be `None` and only warnings will be raised.
        :param suppress_warnings: if True, warnings and errors will never
            be raised from the plugin. If False, warnings and errors will
            be raised following the `benchmark` and `strict_checks`
            behavior.
        """
        super().__init__()
        self.collect_all = collect_all
        self.benchmark = benchmark
        self.strict_checks = strict_checks
        self.suppress_warnings = suppress_warnings

        # Flatten the (possibly nested) metrics list, moving any time
        # metric to the end so that it is updated after all the others.
        flat_metrics_list = []
        time_metric = None
        for metric in metrics:
            if isinstance(metric, Sequence):
                for m in metric:
                    if 'Time' in str(m):
                        time_metric = m
                    else:
                        flat_metrics_list.append(m)
            else:
                if 'Time' in str(metric):
                    time_metric = metric
                else:
                    flat_metrics_list.append(metric)
        if time_metric is not None:
            flat_metrics_list.append(time_metric)
        self.metrics = flat_metrics_list

        if loggers is None:
            loggers = []
        elif not isinstance(loggers, Sequence):
            loggers = [loggers]

        if benchmark is None:
            if not suppress_warnings:
                if strict_checks:
                    raise ValueError("Benchmark cannot be None "
                                     "in strict mode.")
                else:
                    warnings.warn(
                        "No benchmark provided to the evaluation plugin. "
                        "Metrics may be computed on inconsistent portions "
                        "of streams, use at your own risk.")
        else:
            self.complete_test_stream = benchmark.test_stream

        self.loggers: Sequence['StrategyLogger'] = loggers
        if len(self.loggers) == 0:
            warnings.warn('No loggers specified, metrics will not be logged')
        for logger in loggers:
            # CSV-style loggers receive the full metric list up front,
            # e.g. to enumerate every metric they will log.
            if isinstance(logger, (CSVLogger, GenericCSVLogger)):
                logger.all_metrics = self.metrics

        if self.collect_all:
            # For each metric curve, collect all emitted values.
            # The dictionary key is the full metric name. The value is a
            # tuple of two lists: the first gathers the x values (indices
            # representing the time steps at which the corresponding
            # metric value has been emitted), the second gathers the
            # metric values themselves.
            self.all_metric_results = defaultdict(lambda: ([], []))
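            # Example shape after a few emissions (hypothetical metric
            # name and values, shown only to illustrate the layout):
            #   self.all_metric_results['Top1_Acc_Stream/eval_phase'] ==
            #       ([100, 200, 300], [0.61, 0.72, 0.78])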
        # Dictionary of the last value emitted for each metric. The key
        # is the full metric name, the value is the metric value.
        self.last_metric_results = {}

        self._active = True
        """If False, no metrics will be collected."""

    @property
    def active(self):
        return self._active

    @active.setter
    def active(self, value):
        assert value is True or value is False, \
            "Active must be set as either True or False"
        self._active = value
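
    # Usage sketch: metric collection can be paused and resumed, e.g.
    # around a warm-up pass. In Avalanche the plugin is typically
    # reachable as `strategy.evaluator`; the names below are
    # illustrative, not part of this module:
    #
    #     strategy.evaluator.active = False
    #     strategy.train(warmup_experience)  # nothing is recorded
    #     strategy.evaluator.active = True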

    def _update_metrics(self, strategy: 'BaseStrategy', callback: str):
        if not self._active:
            return []
        metric_values = []
        for metric in self.metrics:
            metric_result = getattr(metric, callback)(strategy)
            if isinstance(metric_result, Sequence):
                metric_values += list(metric_result)
            elif metric_result is not None:
                metric_values.append(metric_result)
        for metric_value in metric_values:
            name = metric_value.name
            x = metric_value.x_plot
            val = metric_value.value
            if self.collect_all:
                self.all_metric_results[name][0].append(x)
                self.all_metric_results[name][1].append(val)
            self.last_metric_results[name] = val
        for logger in self.loggers:
            getattr(logger, callback)(strategy, metric_values)
        return metric_values

    def get_last_metrics(self):
        """
        Return a shallow copy of the dictionary with metric names
        as keys and the last metric values as values.

        :return: a dictionary with full metric names as keys and the
            last metric value of each as value.
        """
        return copy(self.last_metric_results)

    def get_all_metrics(self):
        """
        Return the dictionary of all collected metrics.
        This method should be called only when `collect_all` is True.

        :return: if `collect_all` is True, a dictionary with full metric
            names as keys and a tuple of two lists as value. The first
            list gathers the x values (indices representing the time
            steps at which the corresponding metric value has been
            emitted), the second list gathers the metric values. If
            `collect_all` is False, an empty dictionary.
        """
        if self.collect_all:
            return self.all_metric_results
        else:
            return {}

    def reset_last_metrics(self):
        """
        Set the dictionary storing the last value for each metric to an
        empty dict.
        """
        self.last_metric_results = {}

    def before_training(self, strategy: 'BaseStrategy', **kwargs):
        self._update_metrics(strategy, 'before_training')

    def before_training_exp(self, strategy: 'BaseStrategy', **kwargs):
        self._update_metrics(strategy, 'before_training_exp')

    def before_train_dataset_adaptation(self, strategy: 'BaseStrategy',
                                        **kwargs):
        self._update_metrics(strategy, 'before_train_dataset_adaptation')

    def after_train_dataset_adaptation(self, strategy: 'BaseStrategy',
                                       **kwargs):
        self._update_metrics(strategy, 'after_train_dataset_adaptation')

    def before_training_epoch(self, strategy: 'BaseStrategy', **kwargs):
        self._update_metrics(strategy, 'before_training_epoch')

    def before_training_iteration(self, strategy: 'BaseStrategy', **kwargs):
        self._update_metrics(strategy, 'before_training_iteration')

    def before_forward(self, strategy: 'BaseStrategy', **kwargs):
        self._update_metrics(strategy, 'before_forward')

    def after_forward(self, strategy: 'BaseStrategy', **kwargs):
        self._update_metrics(strategy, 'after_forward')

    def before_backward(self, strategy: 'BaseStrategy', **kwargs):
        self._update_metrics(strategy, 'before_backward')

    def after_backward(self, strategy: 'BaseStrategy', **kwargs):
        self._update_metrics(strategy, 'after_backward')

    def after_training_iteration(self, strategy: 'BaseStrategy', **kwargs):
        self._update_metrics(strategy, 'after_training_iteration')

    def before_update(self, strategy: 'BaseStrategy', **kwargs):
        self._update_metrics(strategy, 'before_update')

    def after_update(self, strategy: 'BaseStrategy', **kwargs):
        self._update_metrics(strategy, 'after_update')

    def after_training_epoch(self, strategy: 'BaseStrategy', **kwargs):
        self._update_metrics(strategy, 'after_training_epoch')

    def after_training_exp(self, strategy: 'BaseStrategy', **kwargs):
        self._update_metrics(strategy, 'after_training_exp')

    def after_training(self, strategy: 'BaseStrategy', **kwargs):
        self._update_metrics(strategy, 'after_training')

    def before_eval(self, strategy: 'BaseStrategy', **kwargs):
        self._update_metrics(strategy, 'before_eval')
        msgw = "Evaluation stream is not equal to the complete test " \
               "stream. This may result in inconsistent metrics. Use at " \
               "your own risk."
        msge = "Stream provided to `eval` must be the same as the " \
               "entire evaluation stream."
        if self.benchmark is not None:
            for i, exp in enumerate(self.complete_test_stream):
                try:
                    current_exp = strategy.current_eval_stream[i]
                    if exp.current_experience != \
                            current_exp.current_experience:
                        if not self.suppress_warnings:
                            if self.strict_checks:
                                raise ValueError(msge)
                            else:
                                warnings.warn(msgw)
                except IndexError:
                    if self.strict_checks:
                        raise ValueError(msge)
                    else:
                        warnings.warn(msgw)

    def before_eval_dataset_adaptation(self, strategy: 'BaseStrategy',
                                       **kwargs):
        self._update_metrics(strategy, 'before_eval_dataset_adaptation')

    def after_eval_dataset_adaptation(self, strategy: 'BaseStrategy',
                                      **kwargs):
        self._update_metrics(strategy, 'after_eval_dataset_adaptation')

    def before_eval_exp(self, strategy: 'BaseStrategy', **kwargs):
        self._update_metrics(strategy, 'before_eval_exp')

    def after_eval_exp(self, strategy: 'BaseStrategy', **kwargs):
        self._update_metrics(strategy, 'after_eval_exp')

    def after_eval(self, strategy: 'BaseStrategy', **kwargs):
        self._update_metrics(strategy, 'after_eval')
        if strategy.last_eval:
            # Build a confusion matrix and dump the raw evaluation data
            # for offline analysis.
            plt.figure(figsize=(27, 18))
            tick_marks = np.arange(strategy.model.num_classes)
            confusion_data_dict = {'tick_marks': tick_marks}
            title = ('Acc: ' + str(np.round(strategy.top1_acc * 100, 2))
                     + ' ClsAvgAcc: '
                     + str(np.round(strategy.cls_top1_avg_acc * 100, 2)))
            title = (title + '\nSeqAcc: '
                     + str(np.round(strategy.seq_top1_acc * 100, 2))
                     + ' SeqClsAcc: '
                     + str(np.round(strategy.seq_cls_top1_avg_acc * 100, 2)))
            title = (title + ' SeqAnyAcc: '
                     + str(np.round(strategy.seq_any_acc * 100, 2)))
            cmt = np.zeros((strategy.model.num_classes,
                            strategy.model.num_classes))
            strategy.all_preds = strategy.all_preds.cpu()
            strategy.all_targets = strategy.all_targets.cpu()
            stream_name = strategy.experience.origin_stream.name
            file_name = ('_after_exp'
                         + str(strategy.current_train_exp_seen)
                         + '_' + stream_name)
            if stream_name == 'validation':
                class_acc_dict = strategy.val_cls_acc_dict
            elif stream_name == 'test':
                class_acc_dict = strategy.test_cls_acc_dict
            elif stream_name == 'train':
                class_acc_dict = strategy.train_cls_acc_dict
            confusion_data_dict['title'] = title
            confusion_data_dict['top1_acc'] = strategy.top1_acc
            confusion_data_dict['cls_top1_avg_acc'] = \
                strategy.cls_top1_avg_acc
            confusion_data_dict['seq_top1_acc'] = strategy.seq_top1_acc
            confusion_data_dict['seq_cls_top1_avg_acc'] = \
                strategy.seq_cls_top1_avg_acc
            confusion_data_dict['seq_any_acc'] = strategy.seq_any_acc
            confusion_data_dict['all_preds'] = strategy.all_preds
            confusion_data_dict['all_targets'] = strategy.all_targets
            confusion_data_dict['label_dict'] = strategy.label_dict
            confusion_data_dict['cls_acc_dict'] = class_acc_dict
            confusion_data_dict['unique_train_cls_dict'] = \
                strategy.unique_train_cls_dict
            confusion_data_dict['total_train_cls_dict'] = \
                strategy.total_train_cls_dict
            confusion_data_dict['total_test_cls_dict'] = \
                strategy.total_test_cls_dict
            confusion_data_dict['total_validation_cls_dict'] = \
                strategy.total_validation_cls_dict
            confusion_data_dict['cumulative_dataset_paths'] = \
                strategy.cumulative_dataset_paths
            # Indices correspond to `cumulative_dataset_paths`.
            confusion_data_dict['rehearsal_indicies_picked'] = \
                strategy.rehearsal_indicies_picked
            confusion_data_dict['mir_losses_dict'] = strategy.mir_losses_dict
            confusion_data_dict['memory_dataset_paths'] = \
                strategy.memory_dataset_paths
            confusion_data_dict['memory_dataset_targets'] = \
                strategy.memory_dataset_targets
            # Accumulate the confusion matrix, then row-normalize it.
            for pl, tl in zip(strategy.all_preds, strategy.all_targets):
                cmt[int(tl), int(pl)] = cmt[int(tl), int(pl)] + 1
            cmt = cmt.astype('float') / cmt.sum(axis=1)[:, np.newaxis]
            classes = []
            class_labels = []
            for i in range(strategy.model.num_classes):
                cls_acc = -1
                cls_result = class_acc_dict[i].result()
                if isinstance(cls_result, dict) and 0 in cls_result.keys():
                    if isinstance(cls_result[0], dict):
                        if 0 in cls_result[0].keys():
                            cls_acc = cls_result[0][0]
                    elif isinstance(cls_result[0], float):
                        cls_acc = cls_result[0]
                if i in strategy.label_dict.keys():
                    classes.append(strategy.label_dict[i])
                else:
                    classes.append('Unknown Class')
                cls_label = (classes[-1] + ' cls_acc: '
                             + str(np.round(cls_acc * 100, 2))
                             + '\ntrain/cumu/test/validation: ')
                cls_label = (cls_label
                             + str(self.get_int_value(
                                 strategy.unique_train_cls_dict[i])) + '/'
                             + str(self.get_int_value(
                                 strategy.total_train_cls_dict[i])) + '/'
                             + str(self.get_int_value(
                                 strategy.total_test_cls_dict[i])) + '/'
                             + str(self.get_int_value(
                                 strategy.total_validation_cls_dict[i])))
                class_labels.append(cls_label)
            plt.imshow(cmt, interpolation='nearest', cmap=plt.cm.Blues)
            plt.title(title, fontsize=21)
            plt.colorbar()
            plt.xticks(tick_marks, classes, rotation=90, fontsize=18)
            plt.yticks(tick_marks, class_labels, fontsize=18)
            plt.ylabel('True label', fontsize=18)
            plt.xlabel('Predicted label', fontsize=18)
            print(strategy.log_dir)
            plt.tight_layout()
            datetime_stamp = str(datetime.datetime.now()) \
                .replace(' ', '_').replace(':', '-').replace('.', '-')[:-4]
            plt.savefig(str(strategy.log_dir) + datetime_stamp + file_name
                        + '_confusion.png', facecolor='white')
            plt.close()
            with open(str(strategy.log_dir) + datetime_stamp + file_name
                      + '_confusion_data_dict.pkl', 'wb') as handle:
                pkl.dump(confusion_data_dict, handle,
                         protocol=pkl.HIGHEST_PROTOCOL)
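
            # To inspect a dump offline (sketch; the path is illustrative):
            #     with open('<log_dir>/<stamp>_after_exp0_test'
            #               '_confusion_data_dict.pkl', 'rb') as f:
            #         data = pkl.load(f)
            #     data['all_preds'], data['all_targets']  # raw tensors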

    def get_int_value(self, value):
        """Return `value` as a plain int, unwrapping 0-dim tensors."""
        if torch.is_tensor(value):
            return value.item()
        return int(value)

    def before_eval_iteration(self, strategy: 'BaseStrategy', **kwargs):
        self._update_metrics(strategy, 'before_eval_iteration')

    def before_eval_forward(self, strategy: 'BaseStrategy', **kwargs):
        self._update_metrics(strategy, 'before_eval_forward')

    def after_eval_forward(self, strategy: 'BaseStrategy', **kwargs):
        self._update_metrics(strategy, 'after_eval_forward')

    def after_eval_iteration(self, strategy: 'BaseStrategy', **kwargs):
        self._update_metrics(strategy, 'after_eval_iteration')


default_logger = EvaluationPlugin(
    accuracy_metrics(minibatch=False, epoch=True, experience=True,
                     stream=True),
    loss_metrics(minibatch=False, epoch=True, experience=True, stream=True),
    loggers=[InteractiveLogger()],
    suppress_warnings=True)


__all__ = [
    'EvaluationPlugin',
    'default_logger'
]
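

# ---------------------------------------------------------------------------
# Usage sketch. A minimal smoke test of the plugin API; wiring into a real
# strategy is shown in the comments below. Note that this module's
# `after_eval` expects custom strategy attributes (`last_eval`, `top1_acc`,
# `label_dict`, ...), so the confusion-matrix dump only works with a
# strategy that provides them.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    plugin = EvaluationPlugin(
        accuracy_metrics(epoch=True, experience=True, stream=True),
        loss_metrics(epoch=True, experience=True, stream=True),
        loggers=[InteractiveLogger()],
        suppress_warnings=True)
    print([str(m) for m in plugin.metrics])
    print(plugin.get_last_metrics())  # empty until a strategy runs

    # With a strategy (illustrative names, standard Avalanche pattern):
    #     strategy = Naive(model, optimizer, criterion, evaluator=plugin)
    #     for experience in benchmark.train_stream:
    #         strategy.train(experience)        # returns last metric values
    #         strategy.eval(benchmark.test_stream)
    #     curves = plugin.get_all_metrics()     # full metric curves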