generic_definitions.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167
  1. ################################################################################
  2. # Copyright (c) 2021 ContinualAI. #
  3. # Copyrights licensed under the MIT License. #
  4. # See the accompanying LICENSE file for terms. #
  5. # #
  6. # Date: 12-05-2020 #
  7. # Author(s): Lorenzo Pellegrini #
  8. # E-mail: contact@continualai.org #
  9. # Website: avalanche.continualai.org #
  10. ################################################################################
  11. from abc import abstractmethod
  12. import warnings
  13. try:
  14. from typing import TypeVar, Tuple, List, Protocol, runtime_checkable, \
  15. Sequence, Any, Union, Iterable, Generic
  16. except ImportError:
  17. from typing import TypeVar, Tuple, List, Sequence, Any, Union, Iterable, \
  18. Generic
  19. from typing_extensions import Protocol, runtime_checkable
  20. from avalanche.benchmarks.utils import AvalancheDataset
  21. TScenario = TypeVar('TScenario')
  22. TExperience = TypeVar('TExperience', bound='Experience')
  23. TScenarioStream = TypeVar('TScenarioStream', bound='ScenarioStream')
  24. @runtime_checkable
  25. class Experience(Protocol[TScenario, TScenarioStream]):
  26. """
  27. Definition of an experience. An experience contains a set of patterns
  28. which has become available at a particular time instant. The content and
  29. size of an Experience is defined by the specific benchmark that creates the
  30. IExperience instance.
  31. For instance, an experience of a New Classes scenario will contain all
  32. patterns belonging to a subset of classes of the original training set. An
  33. experience of a New Instance scenario will contain patterns from previously
  34. seen classes.
  35. Experiences of Single Incremental Task (a.k.a. task-free) scenarios are
  36. usually called "batches" while in Multi Task scenarios an Experience is
  37. usually associated to a "task". Finally, in a Multi Incremental Task
  38. scenario the Experience may be composed by patterns from different tasks.
  39. """
  40. origin_stream: TScenarioStream
  41. """
  42. A reference to the original stream from which this experience was obtained.
  43. """
  44. benchmark: TScenario
  45. """
  46. A reference to the benchmark.
  47. """
  48. current_experience: int
  49. """
  50. This is an incremental, 0-indexed, value used to keep track of the position
  51. of current experience in the original stream.
  52. Beware that this value only describes the experience position in the
  53. original stream and may be unrelated to the order in which the strategy will
  54. encounter experiences.
  55. """
  56. dataset: AvalancheDataset
  57. """
  58. The dataset containing the patterns available in this experience.
  59. """
  60. @property
  61. @abstractmethod
  62. def task_labels(self) -> List[int]:
  63. """
  64. This list will contain the unique task labels of the patterns contained
  65. in this experience. In the most common scenarios this will be a list
  66. with a single value. Note: for scenarios that don't produce task labels,
  67. a placeholder task label value like 0 is usually set to each pattern
  68. (see the description of the originating scenario for details).
  69. """
  70. ...
  71. @property
  72. @abstractmethod
  73. def task_label(self) -> int:
  74. """
  75. The task label. This value will never have value "None". However,
  76. for scenarios that don't produce task labels a placeholder value like 0
  77. is usually set. Beware that this field is meant as a shortcut to obtain
  78. a unique task label: it assumes that only patterns labeled with a
  79. single task label are present. If this experience contains patterns from
  80. multiple tasks, accessing this property will result in an exception.
  81. """
  82. ...
  83. @property
  84. def scenario(self) -> TScenario:
  85. """ This property is DEPRECATED, use self.benchmark instead."""
  86. warnings.warn(
  87. 'Using self.scenario is deprecated in Experience. '
  88. 'Consider using self.benchmark instead.', stacklevel=2)
  89. return self.benchmark
  90. class ScenarioStream(Protocol[TScenario, TExperience]):
  91. """
  92. A scenario stream describes a sequence of incremental experiences.
  93. Experiences are described as :class:`IExperience` instances. They contain a
  94. set of patterns which has become available at a particular time instant
  95. along with any optional, scenario specific, metadata.
  96. Most scenario expose two different streams: the training stream and the test
  97. stream.
  98. """
  99. name: str
  100. """
  101. The name of the stream.
  102. """
  103. benchmark: TScenario
  104. """
  105. A reference to the scenario this stream belongs to.
  106. """
  107. @property
  108. def scenario(self) -> TScenario:
  109. """ This property is DEPRECATED, use self.benchmark instead."""
  110. warnings.warn(
  111. 'Using self.scenario is deprecated ScenarioStream. '
  112. 'Consider using self.benchmark instead.', stacklevel=2)
  113. return self.benchmark
  114. def __getitem__(self: TScenarioStream,
  115. experience_idx: Union[int, slice, Iterable[int]]) \
  116. -> Union[TExperience, TScenarioStream]:
  117. """
  118. Gets an experience given its experience index (or a stream slice given
  119. the experience order).
  120. :param experience_idx: An int describing the experience index or an
  121. iterable/slice object describing a slice of this stream.
  122. :return: The Experience instance associated to the given experience
  123. index or a sliced stream instance.
  124. """
  125. ...
  126. def __len__(self) -> int:
  127. """
  128. Used to get the length of this stream (the amount of experiences).
  129. :return: The amount of experiences in this stream.
  130. """
  131. ...
  132. __all__ = [
  133. 'Experience',
  134. 'TExperience',
  135. 'TScenario',
  136. 'ScenarioStream',
  137. 'TScenarioStream'
  138. ]