images_samples.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179
  1. from typing import List, TYPE_CHECKING, Tuple
  2. from torch import Tensor
  3. from torch.utils.data import DataLoader
  4. from torchvision.transforms import ToTensor
  5. from torchvision.utils import make_grid
  6. from avalanche.evaluation.metric_definitions import PluginMetric
  7. from avalanche.evaluation.metric_results import (
  8. MetricResult,
  9. TensorImage,
  10. MetricValue,
  11. )
  12. from avalanche.evaluation.metric_utils import get_metric_name
  13. try:
  14. from typing import Literal
  15. except ImportError:
  16. from typing_extensions import Literal
  17. if TYPE_CHECKING:
  18. from avalanche.training.strategies import BaseStrategy
  19. class ImagesSamplePlugin(PluginMetric):
  20. """
  21. A metric used to sample images at random.
  22. No data augmentation is shown.
  23. Only images in strategy.adapted dataset are used. Images added in the
  24. dataloader (like the replay plugins do) are missed.
  25. :param n_rows: The numbers of raws to use in the grid of images.
  26. :param n_cols: The numbers of columns to use in the grid of images.
  27. :param group: If True, images will be grouped by (task, label)
  28. :param mode: The plugin can be used at train or eval time.
  29. :return: The corresponding plugins.
  30. """
  31. def __init__(
  32. self,
  33. *,
  34. mode: Literal["train", "eval"],
  35. n_cols: int,
  36. n_rows: int,
  37. group: bool = True,
  38. ):
  39. super().__init__()
  40. self.group = group
  41. self.n_rows = n_rows
  42. self.n_cols = n_cols
  43. self.mode = mode
  44. self.images: List[Tensor] = []
  45. self.n_wanted_images = self.n_cols * self.n_rows
  46. def after_train_dataset_adaptation(
  47. self, strategy: "BaseStrategy"
  48. ) -> "MetricResult":
  49. if self.mode == "train":
  50. return self.make_grid_sample(strategy)
  51. def after_eval_dataset_adaptation(
  52. self, strategy: "BaseStrategy"
  53. ) -> "MetricResult":
  54. if self.mode == "eval":
  55. return self.make_grid_sample(strategy)
  56. def make_grid_sample(self, strategy: "BaseStrategy") -> "MetricResult":
  57. self.load_sorted_images(strategy)
  58. return [
  59. MetricValue(
  60. self,
  61. name=get_metric_name(
  62. self,
  63. strategy,
  64. add_experience=self.mode == "eval",
  65. add_task=True,
  66. ),
  67. value=TensorImage(
  68. make_grid(
  69. list(self.images), normalize=False, nrow=self.n_cols
  70. )
  71. ),
  72. x_plot=self.get_global_counter(),
  73. )
  74. ]
  75. def load_sorted_images(self, strategy: "BaseStrategy"):
  76. self.reset()
  77. self.images, labels, tasks = self.load_data(strategy)
  78. if self.group:
  79. self.sort_images(labels, tasks)
  80. def load_data(
  81. self, strategy: "BaseStrategy"
  82. ) -> Tuple[List[Tensor], List[int], List[int]]:
  83. dataloader = self.make_dataloader(strategy)
  84. images, labels, tasks = [], [], []
  85. for batch_images, batch_labels, batch_tasks in dataloader:
  86. n_missing_images = self.n_wanted_images - len(images)
  87. labels.extend(batch_labels[:n_missing_images].tolist())
  88. tasks.extend(batch_tasks[:n_missing_images].tolist())
  89. images.extend(batch_images[:n_missing_images])
  90. if len(images) == self.n_wanted_images:
  91. return images, labels, tasks
  92. def sort_images(self, labels: List[int], tasks: List[int]):
  93. self.images = [
  94. image
  95. for task, label, image in sorted(
  96. zip(tasks, labels, self.images), key=lambda t: (t[0], t[1]),
  97. )
  98. ]
  99. def make_dataloader(self, strategy: "BaseStrategy") -> DataLoader:
  100. return DataLoader(
  101. dataset=strategy.adapted_dataset.replace_transforms(
  102. transform=ToTensor(), target_transform=None,
  103. ),
  104. batch_size=min(strategy.eval_mb_size, self.n_wanted_images),
  105. shuffle=True,
  106. )
  107. def reset(self) -> None:
  108. self.images = []
  109. def result(self) -> List[Tensor]:
  110. return self.images
  111. def __str__(self):
  112. return "images"
  113. def images_samples_metrics(
  114. *,
  115. n_rows: int = 8,
  116. n_cols: int = 8,
  117. group: bool = True,
  118. on_train: bool = True,
  119. on_eval: bool = False,
  120. ) -> List[PluginMetric]:
  121. """
  122. Create the plugins to log some images samples in grids.
  123. No data augmentation is shown.
  124. Only images in strategy.adapted dataset are used. Images added in the
  125. dataloader (like the replay plugins do) are missed.
  126. :param n_rows: The numbers of raws to use in the grid of images.
  127. :param n_cols: The numbers of columns to use in the grid of images.
  128. :param group: If True, images will be grouped by (task, label)
  129. :param on_train: If True, will emit some images samples during training.
  130. :param on_eval: If True, will emit some images samples during evaluation.
  131. :return: The corresponding plugins.
  132. """
  133. plugins = []
  134. if on_eval:
  135. plugins.append(
  136. ImagesSamplePlugin(
  137. mode="eval", n_rows=n_rows, n_cols=n_cols, group=group
  138. )
  139. )
  140. if on_train:
  141. plugins.append(
  142. ImagesSamplePlugin(
  143. mode="train", n_rows=n_rows, n_cols=n_cols, group=group
  144. )
  145. )
  146. return plugins
  147. __all__ = [
  148. images_samples_metrics,
  149. ImagesSamplePlugin,
  150. ]