mean_scores.py

  1. """
  2. This metric was described in the IL2M paper:
  3. E. Belouadah and A. Popescu,
  4. "IL2M: Class Incremental Learning With Dual Memory,"
  5. 2019 IEEE/CVF International Conference on Computer Vision (ICCV),
  6. 2019, pp. 583-592, doi: 10.1109/ICCV.2019.00067.
  7. It selects the scores of the true class and then average them for past and new
  8. classes.
  9. """
from abc import ABC
from collections import defaultdict
from typing import Callable, Dict, Set, TYPE_CHECKING, List, Optional

import torch
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from matplotlib.pyplot import subplots
from torch import Tensor, arange

from avalanche.evaluation import Metric, PluginMetric
from avalanche.evaluation.metric_utils import get_metric_name
from avalanche.evaluation.metrics import Mean
from avalanche.evaluation.metric_results import MetricValue, AlternativeValues

try:
    from typing import Literal
except ImportError:
    from typing_extensions import Literal

if TYPE_CHECKING:
    from avalanche.training.strategies import BaseStrategy
    from avalanche.evaluation.metric_results import MetricResult

LabelCat = Literal["new", "old"]


class MeanScores(Metric):
    """
    Average the scores of the true class by label.
    """

    def __init__(self):
        super().__init__()
        self.label2mean: Dict[int, Mean] = defaultdict(Mean)
        self.reset()

    def reset(self) -> None:
        self.label2mean = defaultdict(Mean)

    @torch.no_grad()
    def update(self, predicted_y: Tensor, true_y: Tensor):
        assert (
            len(predicted_y.size()) == 2
        ), "Predictions need to be logits or scores, not labels"

        if len(true_y.size()) == 2:
            true_y = true_y.argmax(axis=1)

        scores = predicted_y[arange(len(true_y)), true_y]

        for score, label in zip(scores.tolist(), true_y.tolist()):
            self.label2mean[label].update(score)

    def result(self) -> Dict[int, float]:
        return {label: m.result() for label, m in self.label2mean.items()}
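

# Illustrative sketch, not part of the original module: a minimal use of
# MeanScores on synthetic logits. The helper name below is hypothetical and
# exists only as documentation; it is never called by the library.
def _mean_scores_example() -> Dict[int, float]:
    metric = MeanScores()
    # Two samples over three classes: rows are samples, columns are classes.
    logits = torch.tensor([[2.0, 0.5, 0.25], [0.25, 1.5, 0.75]])
    targets = torch.tensor([0, 2])
    metric.update(predicted_y=logits, true_y=targets)
    # result() maps each observed true label to the mean score the model
    # assigned to it, here {0: 2.0, 2: 0.75}.
    return metric.result()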


class MeanNewOldScores(MeanScores):
    """
    Average the scores of the true class by old and new classes.
    """

    def __init__(self):
        super().__init__()
        self.new_classes: Set[int] = set()

    def reset(self) -> None:
        super().reset()
        self.new_classes = set()

    def update_new_classes(self, new_classes: Set[int]):
        self.new_classes.update(new_classes)

    @property
    def old_classes(self) -> Set[int]:
        return set(self.label2mean) - self.new_classes

    def result(self) -> Dict[LabelCat, float]:
        rv = {
            "new": sum(
                (self.label2mean[label] for label in self.new_classes),
                start=Mean(),
            ).result()
        }
        if not self.old_classes:
            return rv
        rv["old"] = sum(
            (self.label2mean[label] for label in self.old_classes),
            start=Mean(),
        ).result()
        return rv
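

# Illustrative sketch, not part of the original module: MeanNewOldScores must
# be told which labels count as "new" (via update_new_classes); every other
# observed label is treated as "old". The helper name is hypothetical.
def _mean_new_old_scores_example() -> Dict[LabelCat, float]:
    metric = MeanNewOldScores()
    metric.update_new_classes({2, 3})
    logits = torch.tensor([[2.0, 0.5, 0.25, 0.0], [0.25, 1.5, 0.75, 0.0]])
    targets = torch.tensor([0, 2])
    metric.update(predicted_y=logits, true_y=targets)
    # The per-label means are grouped by category, e.g. {"new": 0.75, "old": 2.0}.
    return metric.result()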


def default_mean_scores_image_creator(
    label2step2mean_scores: Dict[LabelCat, Dict[int, float]]
) -> Figure:
    """
    Default function to create an image of the evolution of the scores of the
    true class, averaged by new and old classes.

    :param label2step2mean_scores: A dictionary that, for each label category
        ("old" and "new"), contains a dictionary of mean scores indexed by the
        step of the observation.
    :return: The figure containing the graphs.
    """
    fig, ax = subplots()
    ax: Axes

    markers = "*o"

    for marker, (label, step2mean_scores) in zip(
        markers, label2step2mean_scores.items()
    ):
        ax.plot(
            step2mean_scores.keys(),
            step2mean_scores.values(),
            marker,
            label=label,
        )
    ax.legend(loc="lower left")
    ax.set_xlabel("step")
    ax.set_ylabel("mean score")

    fig.tight_layout()
    return fig


MeanScoresImageCreator = Callable[[Dict[LabelCat, Dict[int, float]]], Figure]
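

# Illustrative sketch, not part of the original module: the image creator
# consumes the per-category history of mean scores, keyed by the step
# (iteration counter) at which each mean was recorded. The values below are
# made up for demonstration and the helper name is hypothetical.
def _image_creator_example() -> Figure:
    history: Dict[LabelCat, Dict[int, float]] = {
        "new": {0: 0.9, 100: 0.8, 200: 0.7},
        "old": {100: 0.5, 200: 0.4},
    }
    return default_mean_scores_image_creator(history)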


class MeanScoresPluginMetricABC(PluginMetric, ABC):
    """
    Base class for the plugins that show the scores of the true class,
    averaged by new and old classes.

    :param image_creator: The function to use to create an image of the
        history of the mean scores grouped by old and new classes.
    """

    def __init__(
        self,
        image_creator: Optional[
            MeanScoresImageCreator
        ] = default_mean_scores_image_creator,
    ):
        super().__init__()
        self.mean_scores = MeanNewOldScores()
        self.image_creator = image_creator
        self.label_cat2step2mean: Dict[
            LabelCat, Dict[int, float]
        ] = defaultdict(dict)

    def reset(self) -> None:
        self.mean_scores.reset()

    def update_new_classes(self, strategy: "BaseStrategy"):
        self.mean_scores.update_new_classes(
            strategy.experience.classes_in_this_experience
        )

    def update(self, strategy: "BaseStrategy"):
        self.mean_scores.update(
            predicted_y=strategy.mb_output, true_y=strategy.mb_y
        )

    def result(self) -> Dict[LabelCat, float]:
        return self.mean_scores.result()

    def _package_result(self, strategy: "BaseStrategy") -> "MetricResult":
        label_cat2mean_score: Dict[LabelCat, float] = self.result()

        for label_cat, m in label_cat2mean_score.items():
            self.label_cat2step2mean[label_cat][self.global_it_counter] = m

        base_metric_name = get_metric_name(
            self, strategy, add_experience=False, add_task=False
        )

        rv = [
            MetricValue(
                self,
                name=base_metric_name + f"/{label_cat}_classes",
                value=m,
                x_plot=self.global_it_counter,
            )
            for label_cat, m in label_cat2mean_score.items()
        ]

        if "old" in label_cat2mean_score and "new" in label_cat2mean_score:
            rv.append(
                MetricValue(
                    self,
                    name=base_metric_name + "/new_old_diff",
                    value=label_cat2mean_score["new"]
                    - label_cat2mean_score["old"],
                    x_plot=self.global_it_counter,
                )
            )

        if self.image_creator is not None:
            rv.append(
                MetricValue(
                    self,
                    name=base_metric_name,
                    value=AlternativeValues(
                        self.image_creator(self.label_cat2step2mean),
                        self.label_cat2step2mean,
                    ),
                    x_plot=self.global_it_counter,
                )
            )

        return rv

    def __str__(self):
        return "MeanScores"


class MeanScoresTrainPluginMetric(MeanScoresPluginMetricABC):
    """
    Plugin to show the scores of the true class during the last training
    epoch of each experience, averaged by new and old classes.
    """

    def before_training_epoch(self, strategy: "BaseStrategy") -> None:
        self.reset()
        self.update_new_classes(strategy)

    def after_training_iteration(self, strategy: "BaseStrategy") -> None:
        if strategy.epoch == strategy.train_epochs - 1:
            self.update(strategy)
        super().after_training_iteration(strategy)

    def after_training_epoch(self, strategy: "BaseStrategy") -> "MetricResult":
        if strategy.epoch == strategy.train_epochs - 1:
            return self._package_result(strategy)


class MeanScoresEvalPluginMetric(MeanScoresPluginMetricABC):
    """
    Plugin to show the scores of the true class during evaluation, averaged by
    new and old classes.
    """

    def before_training(self, strategy: "BaseStrategy") -> None:
        self.reset()

    def before_training_exp(self, strategy: "BaseStrategy") -> None:
        self.update_new_classes(strategy)

    def after_eval_iteration(self, strategy: "BaseStrategy") -> None:
        self.update(strategy)
        super().after_eval_iteration(strategy)

    def after_eval(self, strategy: "BaseStrategy") -> "MetricResult":
        return self._package_result(strategy)


def mean_scores_metrics(
    *,
    on_train: bool = True,
    on_eval: bool = True,
    image_creator: Optional[
        MeanScoresImageCreator
    ] = default_mean_scores_image_creator,
) -> List[PluginMetric]:
    """
    Helper to create plugins to show the scores of the true class, averaged by
    new and old classes. The plugins are available during training (for the
    last epoch of each experience) and evaluation.

    :param on_train: If True, the train plugin is created.
    :param on_eval: If True, the eval plugin is created.
    :param image_creator: The function to use to create an image of the
        history of the mean scores grouped by old and new classes.
    :return: The list of plugins that were specified.
    """
    plugins = []

    if on_eval:
        plugins.append(MeanScoresEvalPluginMetric(image_creator=image_creator))
    if on_train:
        plugins.append(
            MeanScoresTrainPluginMetric(image_creator=image_creator)
        )

    return plugins
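

# Illustrative sketch, not part of the original module: these plugins are
# typically handed to an evaluation plugin together with a logger, as in
# standard Avalanche setups. The helper name is hypothetical and the strategy
# wiring (e.g. passing `evaluator` to a strategy constructor) is omitted.
def _mean_scores_metrics_example():
    from avalanche.logging import InteractiveLogger
    from avalanche.training.plugins import EvaluationPlugin

    evaluator = EvaluationPlugin(
        *mean_scores_metrics(on_train=True, on_eval=True),
        loggers=[InteractiveLogger()],
    )
    return evaluator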


__all__ = [
    "mean_scores_metrics",
    "MeanScoresTrainPluginMetric",
    "MeanScoresEvalPluginMetric",
    "MeanScores",
    "MeanNewOldScores",
]