benchmark_generators.py 35 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779
  1. ################################################################################
  2. # Copyright (c) 2021 ContinualAI. #
  3. # Copyrights licensed under the MIT License. #
  4. # See the accompanying LICENSE file for terms. #
  5. # #
  6. # Date: 16-04-2021 #
  7. # Author(s): Lorenzo Pellegrini #
  8. # E-mail: contact@continualai.org #
  9. # Website: avalanche.continualai.org #
  10. ################################################################################
  11. """ In this module the high-level benchmark generators are listed. They are
  12. based on the methods already implemented in the "scenario" module. For the
  13. specific generators we have: "New Classes" (NC) and "New Instances" (NI); For
  14. the generic ones: filelist_benchmark, tensors_benchmark, dataset_benchmark
  15. and paths_benchmark.
  16. """
  17. from functools import partial
  18. from itertools import tee
  19. from typing import Sequence, Optional, Dict, Union, Any, List, Callable, Set, \
  20. Tuple, Iterable, Generator
  21. import torch
  22. from avalanche.benchmarks import GenericCLScenario, Experience, \
  23. GenericScenarioStream
  24. from avalanche.benchmarks.scenarios.generic_benchmark_creation import *
  25. from avalanche.benchmarks.scenarios.generic_cl_scenario import \
  26. TStreamsUserDict, StreamUserDef
  27. from avalanche.benchmarks.scenarios.new_classes.nc_scenario import \
  28. NCScenario
  29. from avalanche.benchmarks.scenarios.new_instances.ni_scenario import NIScenario
  30. from avalanche.benchmarks.utils import concat_datasets_sequentially
  31. from avalanche.benchmarks.utils.avalanche_dataset import SupportedDataset, \
  32. AvalancheDataset, AvalancheDatasetType, AvalancheSubset
  33. def nc_benchmark(
  34. train_dataset: Union[
  35. Sequence[SupportedDataset], SupportedDataset],
  36. test_dataset: Union[
  37. Sequence[SupportedDataset], SupportedDataset],
  38. n_experiences: int,
  39. task_labels: bool,
  40. *,
  41. shuffle: bool = True,
  42. seed: Optional[int] = None,
  43. fixed_class_order: Sequence[int] = None,
  44. per_exp_classes: Dict[int, int] = None,
  45. class_ids_from_zero_from_first_exp: bool = False,
  46. class_ids_from_zero_in_each_exp: bool = False,
  47. one_dataset_per_exp: bool = False,
  48. train_transform=None,
  49. eval_transform=None,
  50. reproducibility_data: Dict[str, Any] = None) -> NCScenario:
  51. """
  52. This is the high-level benchmark instances generator for the
  53. "New Classes" (NC) case. Given a sequence of train and test datasets creates
  54. the continual stream of data as a series of experiences. Each experience
  55. will contain all the instances belonging to a certain set of classes and a
  56. class won't be assigned to more than one experience.
  57. This is the reference helper function for creating instances of Class- or
  58. Task-Incremental benchmarks.
  59. The ``task_labels`` parameter determines if each incremental experience has
  60. an increasing task label or if, at the contrary, a default task label "0"
  61. has to be assigned to all experiences. This can be useful when
  62. differentiating between Single-Incremental-Task and Multi-Task scenarios.
  63. There are other important parameters that can be specified in order to tweak
  64. the behaviour of the resulting benchmark. Please take a few minutes to read
  65. and understand them as they may save you a lot of work.
  66. This generator features a integrated reproducibility mechanism that allows
  67. the user to store and later re-load a benchmark. For more info see the
  68. ``reproducibility_data`` parameter.
  69. :param train_dataset: A list of training datasets, or a single dataset.
  70. :param test_dataset: A list of test datasets, or a single test dataset.
  71. :param n_experiences: The number of incremental experience. This is not used
  72. when using multiple train/test datasets with the ``one_dataset_per_exp``
  73. parameter set to True.
  74. :param task_labels: If True, each experience will have an ascending task
  75. label. If False, the task label will be 0 for all the experiences.
  76. :param shuffle: If True, the class (or experience) order will be shuffled.
  77. Defaults to True.
  78. :param seed: If ``shuffle`` is True and seed is not None, the class (or
  79. experience) order will be shuffled according to the seed. When None, the
  80. current PyTorch random number generator state will be used. Defaults to
  81. None.
  82. :param fixed_class_order: If not None, the class order to use (overrides
  83. the shuffle argument). Very useful for enhancing reproducibility.
  84. Defaults to None.
  85. :param per_exp_classes: Is not None, a dictionary whose keys are
  86. (0-indexed) experience IDs and their values are the number of classes
  87. to include in the respective experiences. The dictionary doesn't
  88. have to contain a key for each experience! All the remaining experiences
  89. will contain an equal amount of the remaining classes. The
  90. remaining number of classes must be divisible without remainder
  91. by the remaining number of experiences. For instance,
  92. if you want to include 50 classes in the first experience
  93. while equally distributing remaining classes across remaining
  94. experiences, just pass the "{0: 50}" dictionary as the
  95. per_experience_classes parameter. Defaults to None.
  96. :param class_ids_from_zero_from_first_exp: If True, original class IDs
  97. will be remapped so that they will appear as having an ascending
  98. order. For instance, if the resulting class order after shuffling
  99. (or defined by fixed_class_order) is [23, 34, 11, 7, 6, ...] and
  100. class_ids_from_zero_from_first_exp is True, then all the patterns
  101. belonging to class 23 will appear as belonging to class "0",
  102. class "34" will be mapped to "1", class "11" to "2" and so on.
  103. This is very useful when drawing confusion matrices and when dealing
  104. with algorithms with dynamic head expansion. Defaults to False.
  105. Mutually exclusive with the ``class_ids_from_zero_in_each_exp``
  106. parameter.
  107. :param class_ids_from_zero_in_each_exp: If True, original class IDs
  108. will be mapped to range [0, n_classes_in_exp) for each experience.
  109. Defaults to False. Mutually exclusive with the
  110. ``class_ids_from_zero_from_first_exp`` parameter.
  111. :param one_dataset_per_exp: available only when multiple train-test
  112. datasets are provided. If True, each dataset will be treated as a
  113. experience. Mutually exclusive with the ``per_experience_classes`` and
  114. ``fixed_class_order`` parameters. Overrides the ``n_experiences``
  115. parameter. Defaults to False.
  116. :param train_transform: The transformation to apply to the training data,
  117. e.g. a random crop, a normalization or a concatenation of different
  118. transformations (see torchvision.transform documentation for a
  119. comprehensive list of possible transformations). Defaults to None.
  120. :param eval_transform: The transformation to apply to the test data,
  121. e.g. a random crop, a normalization or a concatenation of different
  122. transformations (see torchvision.transform documentation for a
  123. comprehensive list of possible transformations). Defaults to None.
  124. :param reproducibility_data: If not None, overrides all the other
  125. benchmark definition options. This is usually a dictionary containing
  126. data used to reproduce a specific experiment. One can use the
  127. ``get_reproducibility_data`` method to get (and even distribute)
  128. the experiment setup so that it can be loaded by passing it as this
  129. parameter. In this way one can be sure that the same specific
  130. experimental setup is being used (for reproducibility purposes).
  131. Beware that, in order to reproduce an experiment, the same train and
  132. test datasets must be used. Defaults to None.
  133. :return: A properly initialized :class:`NCScenario` instance.
  134. """
  135. if class_ids_from_zero_from_first_exp and class_ids_from_zero_in_each_exp:
  136. raise ValueError('Invalid mutually exclusive options '
  137. 'class_ids_from_zero_from_first_exp and '
  138. 'classes_ids_from_zero_in_each_exp set at the '
  139. 'same time')
  140. if isinstance(train_dataset, list) or isinstance(train_dataset, tuple):
  141. # Multi-dataset setting
  142. if len(train_dataset) != len(test_dataset):
  143. raise ValueError('Train/test dataset lists must contain the '
  144. 'exact same number of datasets')
  145. if per_exp_classes and one_dataset_per_exp:
  146. raise ValueError(
  147. 'Both per_experience_classes and one_dataset_per_exp are'
  148. 'used, but those options are mutually exclusive')
  149. if fixed_class_order and one_dataset_per_exp:
  150. raise ValueError(
  151. 'Both fixed_class_order and one_dataset_per_exp are'
  152. 'used, but those options are mutually exclusive')
  153. seq_train_dataset, seq_test_dataset, mapping = \
  154. concat_datasets_sequentially(train_dataset, test_dataset)
  155. if one_dataset_per_exp:
  156. # If one_dataset_per_exp is True, each dataset will be treated as
  157. # a experience. In this benchmark, shuffle refers to the experience
  158. # order, not to the class one.
  159. fixed_class_order, per_exp_classes = \
  160. _one_dataset_per_exp_class_order(mapping, shuffle, seed)
  161. # We pass a fixed_class_order to the NCGenericScenario
  162. # constructor, so we don't need shuffling.
  163. shuffle = False
  164. seed = None
  165. # Overrides n_experiences (and per_experience_classes, already done)
  166. n_experiences = len(train_dataset)
  167. train_dataset, test_dataset = seq_train_dataset, seq_test_dataset
  168. transform_groups = dict(
  169. train=(train_transform, None),
  170. eval=(eval_transform, None)
  171. )
  172. # Datasets should be instances of AvalancheDataset
  173. train_dataset = AvalancheDataset(
  174. train_dataset,
  175. transform_groups=transform_groups,
  176. initial_transform_group='train',
  177. dataset_type=AvalancheDatasetType.CLASSIFICATION)
  178. test_dataset = AvalancheDataset(
  179. test_dataset,
  180. transform_groups=transform_groups,
  181. initial_transform_group='eval',
  182. dataset_type=AvalancheDatasetType.CLASSIFICATION)
  183. return NCScenario(train_dataset, test_dataset, n_experiences, task_labels,
  184. shuffle, seed, fixed_class_order, per_exp_classes,
  185. class_ids_from_zero_from_first_exp,
  186. class_ids_from_zero_in_each_exp,
  187. reproducibility_data)
  188. def ni_benchmark(
  189. train_dataset: Union[
  190. Sequence[SupportedDataset], SupportedDataset],
  191. test_dataset: Union[
  192. Sequence[SupportedDataset], SupportedDataset],
  193. n_experiences: int,
  194. *,
  195. task_labels: bool = False,
  196. shuffle: bool = True,
  197. seed: Optional[int] = None,
  198. balance_experiences: bool = False,
  199. min_class_patterns_in_exp: int = 0,
  200. fixed_exp_assignment: Optional[Sequence[Sequence[int]]] = None,
  201. train_transform=None,
  202. eval_transform=None,
  203. reproducibility_data: Optional[Dict[str, Any]] = None) \
  204. -> NIScenario:
  205. """
  206. This is the high-level benchmark instances generator for the
  207. "New Instances" (NI) case. Given a sequence of train and test datasets
  208. creates the continual stream of data as a series of experiences.
  209. This is the reference helper function for creating instances of
  210. Domain-Incremental benchmarks.
  211. The ``task_labels`` parameter determines if each incremental experience has
  212. an increasing task label or if, at the contrary, a default task label "0"
  213. has to be assigned to all experiences. This can be useful when
  214. differentiating between Single-Incremental-Task and Multi-Task scenarios.
  215. There are other important parameters that can be specified in order to tweak
  216. the behaviour of the resulting benchmark. Please take a few minutes to read
  217. and understand them as they may save you a lot of work.
  218. This generator features an integrated reproducibility mechanism that allows
  219. the user to store and later re-load a benchmark. For more info see the
  220. ``reproducibility_data`` parameter.
  221. :param train_dataset: A list of training datasets, or a single dataset.
  222. :param test_dataset: A list of test datasets, or a single test dataset.
  223. :param n_experiences: The number of experiences.
  224. :param task_labels: If True, each experience will have an ascending task
  225. label. If False, the task label will be 0 for all the experiences.
  226. :param shuffle: If True, patterns order will be shuffled.
  227. :param seed: A valid int used to initialize the random number generator.
  228. Can be None.
  229. :param balance_experiences: If True, pattern of each class will be equally
  230. spread across all experiences. If False, patterns will be assigned to
  231. experiences in a complete random way. Defaults to False.
  232. :param min_class_patterns_in_exp: The minimum amount of patterns of
  233. every class that must be assigned to every experience. Compatible with
  234. the ``balance_experiences`` parameter. An exception will be raised if
  235. this constraint can't be satisfied. Defaults to 0.
  236. :param fixed_exp_assignment: If not None, the pattern assignment
  237. to use. It must be a list with an entry for each experience. Each entry
  238. is a list that contains the indexes of patterns belonging to that
  239. experience. Overrides the ``shuffle``, ``balance_experiences`` and
  240. ``min_class_patterns_in_exp`` parameters.
  241. :param train_transform: The transformation to apply to the training data,
  242. e.g. a random crop, a normalization or a concatenation of different
  243. transformations (see torchvision.transform documentation for a
  244. comprehensive list of possible transformations). Defaults to None.
  245. :param eval_transform: The transformation to apply to the test data,
  246. e.g. a random crop, a normalization or a concatenation of different
  247. transformations (see torchvision.transform documentation for a
  248. comprehensive list of possible transformations). Defaults to None.
  249. :param reproducibility_data: If not None, overrides all the other
  250. benchmark definition options, including ``fixed_exp_assignment``.
  251. This is usually a dictionary containing data used to
  252. reproduce a specific experiment. One can use the
  253. ``get_reproducibility_data`` method to get (and even distribute)
  254. the experiment setup so that it can be loaded by passing it as this
  255. parameter. In this way one can be sure that the same specific
  256. experimental setup is being used (for reproducibility purposes).
  257. Beware that, in order to reproduce an experiment, the same train and
  258. test datasets must be used. Defaults to None.
  259. :return: A properly initialized :class:`NIScenario` instance.
  260. """
  261. seq_train_dataset, seq_test_dataset = train_dataset, test_dataset
  262. if isinstance(train_dataset, list) or isinstance(train_dataset, tuple):
  263. if len(train_dataset) != len(test_dataset):
  264. raise ValueError('Train/test dataset lists must contain the '
  265. 'exact same number of datasets')
  266. seq_train_dataset, seq_test_dataset, _ = \
  267. concat_datasets_sequentially(train_dataset, test_dataset)
  268. transform_groups = dict(
  269. train=(train_transform, None),
  270. eval=(eval_transform, None)
  271. )
  272. # Datasets should be instances of AvalancheDataset
  273. seq_train_dataset = AvalancheDataset(
  274. seq_train_dataset,
  275. transform_groups=transform_groups,
  276. initial_transform_group='train',
  277. dataset_type=AvalancheDatasetType.CLASSIFICATION)
  278. seq_test_dataset = AvalancheDataset(
  279. seq_test_dataset,
  280. transform_groups=transform_groups,
  281. initial_transform_group='eval',
  282. dataset_type=AvalancheDatasetType.CLASSIFICATION)
  283. return NIScenario(
  284. seq_train_dataset, seq_test_dataset,
  285. n_experiences,
  286. task_labels,
  287. shuffle=shuffle, seed=seed,
  288. balance_experiences=balance_experiences,
  289. min_class_patterns_in_exp=min_class_patterns_in_exp,
  290. fixed_exp_assignment=fixed_exp_assignment,
  291. reproducibility_data=reproducibility_data)
# Here we define some high-level APIs as aliases of their mid-level
# counterparts. This was done mainly because the implementation of the
# mid-level API is now quite stable and not particularly complex.
# NOTE(review): the create_* names are presumably provided by the wildcard
# import from generic_benchmark_creation - confirm.
dataset_benchmark = create_multi_dataset_generic_benchmark
filelist_benchmark = create_generic_benchmark_from_filelists
paths_benchmark = create_generic_benchmark_from_paths
tensors_benchmark = create_generic_benchmark_from_tensor_lists
lazy_benchmark = create_lazy_generic_benchmark
  300. def _one_dataset_per_exp_class_order(
  301. class_list_per_exp: Sequence[Sequence[int]],
  302. shuffle: bool, seed: Union[int, None]) -> (List[int], Dict[int, int]):
  303. """
  304. Utility function that shuffles the class order by keeping classes from the
  305. same experience together. Each experience is defined by a different entry in
  306. the class_list_per_exp parameter.
  307. :param class_list_per_exp: A list of class lists, one for each experience
  308. :param shuffle: If True, the experience order will be shuffled. If False,
  309. this function will return the concatenation of lists from the
  310. class_list_per_exp parameter.
  311. :param seed: If not None, an integer used to initialize the random
  312. number generator.
  313. :returns: A class order that keeps class IDs from the same experience
  314. together (adjacent).
  315. """
  316. dataset_order = list(range(len(class_list_per_exp)))
  317. if shuffle:
  318. if seed is not None:
  319. torch.random.manual_seed(seed)
  320. dataset_order = torch.as_tensor(dataset_order)[
  321. torch.randperm(len(dataset_order))].tolist()
  322. fixed_class_order = []
  323. classes_per_exp = {}
  324. for dataset_position, dataset_idx in enumerate(dataset_order):
  325. fixed_class_order.extend(class_list_per_exp[dataset_idx])
  326. classes_per_exp[dataset_position] = \
  327. len(class_list_per_exp[dataset_idx])
  328. return fixed_class_order, classes_per_exp
  329. def fixed_size_experience_split_strategy(
  330. experience_size: int, shuffle: bool, drop_last: bool,
  331. experience: Experience):
  332. """
  333. The default splitting strategy used by :func:`data_incremental_benchmark`.
  334. This splitting strategy simply splits the experience in smaller experiences
  335. of size `experience_size`.
  336. When taking inspiration for your custom splitting strategy, please consider
  337. that all parameters preceding `experience` are filled by
  338. :func:`data_incremental_benchmark` by using `partial` from the `functools`
  339. standard library. A custom splitting strategy must have only a single
  340. parameter: the experience. Consider wrapping your custom splitting strategy
  341. with `partial` if more parameters are needed.
  342. Also consider that the stream name of the experience can be obtained by
  343. using `experience.origin_stream.name`.
  344. :param experience_size: The experience size (number of instances).
  345. :param shuffle: If True, instances will be shuffled before splitting.
  346. :param drop_last: If True, the last mini-experience will be dropped if
  347. not of size `experience_size`
  348. :param experience: The experience to split.
  349. :return: The list of datasets that will be used to create the
  350. mini-experiences.
  351. """
  352. exp_dataset = experience.dataset
  353. exp_indices = list(range(len(exp_dataset)))
  354. result_datasets = []
  355. if shuffle:
  356. exp_indices = \
  357. torch.as_tensor(exp_indices)[
  358. torch.randperm(len(exp_indices))
  359. ].tolist()
  360. init_idx = 0
  361. while init_idx < len(exp_indices):
  362. final_idx = init_idx + experience_size # Exclusive
  363. if final_idx > len(exp_indices):
  364. if drop_last:
  365. break
  366. final_idx = len(exp_indices)
  367. result_datasets.append(AvalancheSubset(
  368. exp_dataset, indices=exp_indices[init_idx:final_idx]))
  369. init_idx = final_idx
  370. return result_datasets
  371. def data_incremental_benchmark(
  372. benchmark_instance: GenericCLScenario,
  373. experience_size: int,
  374. shuffle: bool = False,
  375. drop_last: bool = False,
  376. split_streams: Sequence[str] = ('train',),
  377. custom_split_strategy: Callable[[Experience],
  378. Sequence[AvalancheDataset]] = None,
  379. experience_factory: Callable[[GenericScenarioStream, int],
  380. Experience] = None):
  381. """
  382. High-level benchmark generator for a Data Incremental setup.
  383. This generator accepts an existing benchmark instance and returns a version
  384. of it in which experiences have been split in order to produce a
  385. Data Incremental stream.
  386. In its base form this generator will split train experiences in experiences
  387. of a fixed, configurable, size. The split can be also performed on other
  388. streams (like the test one) if needed.
  389. The `custom_split_strategy` parameter can be used if a more specific
  390. splitting is required.
  391. Beware that experience splitting is NOT executed in a lazy way. This
  392. means that the splitting process takes place immediately. Consider
  393. optimizing the split process for speed when using a custom splitting
  394. strategy.
  395. Please note that each mini-experience will have a task labels field
  396. equal to the one of the originating experience.
  397. The `complete_test_set_only` field of the resulting benchmark instance
  398. will be `True` only if the same field of original benchmark instance is
  399. `True` and if the resulting test stream contains exactly one experience.
  400. :param benchmark_instance: The benchmark to split.
  401. :param experience_size: The size of the experience, as an int. Ignored
  402. if `custom_split_strategy` is used.
  403. :param shuffle: If True, experiences will be split by first shuffling
  404. instances in each experience. This will use the default PyTorch
  405. random number generator at its current state. Defaults to False.
  406. Ignored if `custom_split_strategy` is used.
  407. :param drop_last: If True, if the last experience doesn't contain
  408. `experience_size` instances, then the last experience will be dropped.
  409. Defaults to False. Ignored if `custom_split_strategy` is used.
  410. :param split_streams: The list of streams to split. By default only the
  411. "train" stream will be split.
  412. :param custom_split_strategy: A function that implements a custom splitting
  413. strategy. The function must accept an experience and return a list
  414. of datasets each describing an experience. Defaults to None, which means
  415. that the standard splitting strategy will be used (which creates
  416. experiences of size `experience_size`).
  417. A good starting to understand the mechanism is to look at the
  418. implementation of the standard splitting function
  419. :func:`fixed_size_experience_split_strategy`.
  420. :param experience_factory: The experience factory.
  421. Defaults to :class:`GenericExperience`.
  422. :return: The Data Incremental benchmark instance.
  423. """
  424. split_strategy = custom_split_strategy
  425. if split_strategy is None:
  426. split_strategy = partial(
  427. fixed_size_experience_split_strategy, experience_size, shuffle,
  428. drop_last)
  429. stream_definitions: TStreamsUserDict = dict(
  430. benchmark_instance.stream_definitions)
  431. for stream_name in split_streams:
  432. if stream_name not in stream_definitions:
  433. raise ValueError(f'Stream {stream_name} could not be found in the '
  434. f'benchmark instance')
  435. stream = getattr(benchmark_instance, f'{stream_name}_stream')
  436. split_datasets: List[AvalancheDataset] = []
  437. split_task_labels: List[Set[int]] = []
  438. exp: Experience
  439. for exp in stream:
  440. experiences = split_strategy(exp)
  441. split_datasets += experiences
  442. for _ in range(len(experiences)):
  443. split_task_labels.append(set(exp.task_labels))
  444. stream_def = StreamUserDef(
  445. split_datasets, split_task_labels,
  446. stream_definitions[stream_name].origin_dataset,
  447. False)
  448. stream_definitions[stream_name] = stream_def
  449. complete_test_set_only = benchmark_instance.complete_test_set_only and \
  450. len(stream_definitions['test'].exps_data) == 1
  451. return GenericCLScenario(stream_definitions=stream_definitions,
  452. complete_test_set_only=complete_test_set_only,
  453. experience_factory=experience_factory)
  454. def random_validation_split_strategy(
  455. validation_size: Union[int, float],
  456. shuffle: bool,
  457. experience: Experience):
  458. """
  459. The default splitting strategy used by
  460. :func:`benchmark_with_validation_stream`.
  461. This splitting strategy simply splits the experience in two experiences (
  462. train and validation) of size `validation_size`.
  463. When taking inspiration for your custom splitting strategy, please consider
  464. that all parameters preceding `experience` are filled by
  465. :func:`benchmark_with_validation_stream` by using `partial` from the
  466. `functools` standard library. A custom splitting strategy must have only
  467. a single parameter: the experience. Consider wrapping your custom
  468. splitting strategy with `partial` if more parameters are needed.
  469. Also consider that the stream name of the experience can be obtained by
  470. using `experience.origin_stream.name`.
  471. :param validation_size: The number of instances to allocate to the
  472. validation experience. Can be an int value or a float between 0 and 1.
  473. :param shuffle: If True, instances will be shuffled before splitting.
  474. Otherwise, the first instances will be allocated to the training
  475. dataset by leaving the last ones to the validation dataset.
  476. :param experience: The experience to split.
  477. :return: A tuple containing 2 elements: the new training and validation
  478. datasets.
  479. """
  480. exp_dataset = experience.dataset
  481. exp_indices = list(range(len(exp_dataset)))
  482. if shuffle:
  483. exp_indices = \
  484. torch.as_tensor(exp_indices)[
  485. torch.randperm(len(exp_indices))
  486. ].tolist()
  487. if 0.0 <= validation_size <= 1.0:
  488. valid_n_instances = int(validation_size * len(exp_dataset))
  489. else:
  490. valid_n_instances = int(validation_size)
  491. if valid_n_instances > len(exp_dataset):
  492. raise ValueError(
  493. f'Can\'t create the validation experience: nott enough '
  494. f'instances. Required {valid_n_instances}, got only'
  495. f'{len(exp_dataset)}')
  496. train_n_instances = len(exp_dataset) - valid_n_instances
  497. result_train_dataset = AvalancheSubset(
  498. exp_dataset, indices=exp_indices[:train_n_instances])
  499. result_valid_dataset = AvalancheSubset(
  500. exp_dataset, indices=exp_indices[train_n_instances:])
  501. return result_train_dataset, result_valid_dataset
  502. def _gen_split(split_generator: Iterable[Tuple[AvalancheDataset,
  503. AvalancheDataset]]) -> \
  504. Tuple[Generator[AvalancheDataset, None, None],
  505. Generator[AvalancheDataset, None, None]]:
  506. """
  507. Internal utility function to split the train-validation generator
  508. into two distinct generators (one for the train stream and another one
  509. for the valid stream).
  510. :param split_generator: The lazy stream generator returning tuples of train
  511. and valid datasets.
  512. :return: Two generators (one for the train, one for the valuid).
  513. """
  514. # For more info: https://stackoverflow.com/a/28030261
  515. gen_a, gen_b = tee(split_generator, 2)
  516. return (a for a, b in gen_a), (b for a, b in gen_b)
  517. def _lazy_train_val_split(
  518. split_strategy: Callable[[Experience],
  519. Tuple[AvalancheDataset, AvalancheDataset]],
  520. experiences: Iterable[Experience]) -> \
  521. Generator[Tuple[AvalancheDataset, AvalancheDataset], None, None]:
  522. """
  523. Creates a generator operating around the split strategy and the
  524. experiences stream.
  525. :param split_strategy: The strategy used to split each experience in train
  526. and validation datasets.
  527. :return: A generator returning a 2 elements tuple (the train and validation
  528. datasets).
  529. """
  530. for new_experience in experiences:
  531. yield split_strategy(new_experience)
  532. def benchmark_with_validation_stream(
  533. benchmark_instance: GenericCLScenario,
  534. validation_size: Union[int, float],
  535. shuffle: bool = False,
  536. input_stream: str = 'train',
  537. output_stream: str = 'valid',
  538. custom_split_strategy: Callable[[Experience],
  539. Tuple[AvalancheDataset,
  540. AvalancheDataset]] = None,
  541. *,
  542. experience_factory: Callable[[GenericScenarioStream, int],
  543. Experience] = None,
  544. lazy_splitting: bool = None):
  545. """
  546. Helper that can be used to obtain a benchmark with a validation stream.
  547. This generator accepts an existing benchmark instance and returns a version
  548. of it in which a validation stream has been added.
  549. In its base form this generator will split train experiences to extract
  550. validation experiences of a fixed (by number of instances or relative
  551. size), configurable, size. The split can be also performed on other
  552. streams if needed and the name of the resulting validation stream can
  553. be configured too.
  554. Each validation experience will be extracted directly from a single training
  555. experience. Patterns selected for the validation experience will be removed
  556. from the training one.
  557. If shuffle is True, the validation stream will be created randomly.
  558. Beware that no kind of class balancing is done.
  559. The `custom_split_strategy` parameter can be used if a more specific
  560. splitting is required.
  561. Please note that the resulting experiences will have a task labels field
  562. equal to the one of the originating experience.
  563. Experience splitting can be executed in a lazy way. This behavior can be
  564. controlled using the `lazy_splitting` parameter. By default, experiences
  565. are split in a lazy way only when the input stream is lazily generated.
  566. :param benchmark_instance: The benchmark to split.
  567. :param validation_size: The size of the validation experience, as an int
  568. or a float between 0 and 1. Ignored if `custom_split_strategy` is used.
  569. :param shuffle: If True, patterns will be allocated to the validation
  570. stream randomly. This will use the default PyTorch random number
  571. generator at its current state. Defaults to False. Ignored if
  572. `custom_split_strategy` is used. If False, the first instances will be
  573. allocated to the training dataset by leaving the last ones to the
  574. validation dataset.
  575. :param input_stream: The name of the input stream. Defaults to 'train'.
  576. :param output_stream: The name of the output stream. Defaults to 'valid'.
  577. :param custom_split_strategy: A function that implements a custom splitting
  578. strategy. The function must accept an experience and return a tuple
  579. containing the new train and validation dataset. Defaults to None,
  580. which means that the standard splitting strategy will be used (which
  581. creates experiences according to `validation_size` and `shuffle`).
  582. A good starting to understand the mechanism is to look at the
  583. implementation of the standard splitting function
  584. :func:`random_validation_split_strategy`.
  585. :param experience_factory: The experience factory. Defaults to
  586. :class:`GenericExperience`.
  587. :param lazy_splitting: If True, the stream will be split in a lazy way.
  588. If False, the stream will be split immediately. Defaults to None, which
  589. means that the stream will be split in a lazy or non-lazy way depending
  590. on the laziness of the `input_stream`.
  591. :return: A benchmark instance in which the validation stream has been added.
  592. """
  593. split_strategy = custom_split_strategy
  594. if split_strategy is None:
  595. split_strategy = partial(
  596. random_validation_split_strategy, validation_size,
  597. shuffle)
  598. stream_definitions: TStreamsUserDict = dict(
  599. benchmark_instance.stream_definitions)
  600. streams = benchmark_instance.streams
  601. if input_stream not in streams:
  602. raise ValueError(f'Stream {input_stream} could not be found in the '
  603. f'benchmark instance')
  604. if output_stream in streams:
  605. raise ValueError(f'Stream {output_stream} already exists in the '
  606. f'benchmark instance')
  607. stream = streams[input_stream]
  608. split_lazily = lazy_splitting
  609. if split_lazily is None:
  610. split_lazily = stream_definitions[input_stream].is_lazy
  611. exps_tasks_labels = list(
  612. stream_definitions[input_stream].exps_task_labels
  613. )
  614. if not split_lazily:
  615. # Classic static splitting
  616. train_exps_source = []
  617. valid_exps_source = []
  618. exp: Experience
  619. for exp in stream:
  620. train_exp, valid_exp = split_strategy(exp)
  621. train_exps_source.append(train_exp)
  622. valid_exps_source.append(valid_exp)
  623. else:
  624. # Lazy splitting (based on a generator)
  625. split_generator = _lazy_train_val_split(split_strategy, stream)
  626. train_exps_gen, valid_exps_gen = _gen_split(split_generator)
  627. train_exps_source = (train_exps_gen, len(stream))
  628. valid_exps_source = (valid_exps_gen, len(stream))
  629. train_stream_def = \
  630. StreamUserDef(
  631. train_exps_source,
  632. exps_tasks_labels,
  633. stream_definitions[input_stream].origin_dataset,
  634. split_lazily)
  635. valid_stream_def = \
  636. StreamUserDef(
  637. valid_exps_source,
  638. exps_tasks_labels,
  639. stream_definitions[input_stream].origin_dataset,
  640. split_lazily)
  641. stream_definitions[input_stream] = train_stream_def
  642. stream_definitions[output_stream] = valid_stream_def
  643. complete_test_set_only = benchmark_instance.complete_test_set_only
  644. return GenericCLScenario(stream_definitions=stream_definitions,
  645. complete_test_set_only=complete_test_set_only,
  646. experience_factory=experience_factory)
  647. __all__ = [
  648. 'nc_benchmark',
  649. 'ni_benchmark',
  650. 'dataset_benchmark',
  651. 'filelist_benchmark',
  652. 'paths_benchmark',
  653. 'tensors_benchmark',
  654. 'data_incremental_benchmark',
  655. 'benchmark_with_validation_stream'
  656. ]