################################################################################
# Copyright (c) 2021 ContinualAI.                                              #
# Copyrights licensed under the MIT License.                                   #
# See the accompanying LICENSE file for terms.                                 #
#                                                                              #
# Date: 15-03-2020                                                             #
# Author(s): Antonio Carta                                                     #
# E-mail: contact@continualai.org                                              #
# Website: avalanche.continualai.org                                           #
################################################################################
import unittest

import torch
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from torch.optim import SGD
from torch.nn import CrossEntropyLoss
from torch.utils.data import TensorDataset

from avalanche.benchmarks import nc_benchmark
from avalanche.benchmarks.utils import AvalancheConcatDataset
from avalanche.benchmarks.utils.data_loader import \
    ReplayDataLoader, TaskBalancedDataLoader, GroupBalancedDataLoader
from avalanche.models import SimpleMLP
from avalanche.training.plugins import ReplayPlugin
from avalanche.training.strategies import Naive
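

# Build a small synthetic benchmark so the data loader tests stay fast:
# a 10-class classification problem with 6 informative features, split
# into 5 class-incremental experiences with task labels.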
def get_fast_benchmark():
    n_samples_per_class = 100
    dataset = make_classification(
        n_samples=10 * n_samples_per_class,
        n_classes=10,
        n_features=6, n_informative=6, n_redundant=0)
    X = torch.from_numpy(dataset[0]).float()
    y = torch.from_numpy(dataset[1]).long()
    train_X, test_X, train_y, test_y = train_test_split(
        X, y, train_size=0.6, shuffle=True, stratify=y)
    train_dataset = TensorDataset(train_X, train_y)
    test_dataset = TensorDataset(test_X, test_y)
    my_nc_benchmark = nc_benchmark(train_dataset, test_dataset, 5,
                                   task_labels=True)
    return my_nc_benchmark


class DataLoaderTests(unittest.TestCase):
    def test_basic(self):
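        """Smoke test: each of the three data loaders should iterate
        over the benchmark data without raising."""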
        benchmark = get_fast_benchmark()
        ds = [el.dataset for el in benchmark.train_stream]
        data = AvalancheConcatDataset(ds)
        dl = TaskBalancedDataLoader(data)
        for el in dl:
            pass
        dl = GroupBalancedDataLoader(ds)
        for el in dl:
            pass
        dl = ReplayDataLoader(data, data)
        for el in dl:
            pass

    def test_dataload_reinit(self):
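        """Train with a replay plugin on two consecutive experiences,
        checking that the dataloader is re-created cleanly between
        experiences."""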
        benchmark = get_fast_benchmark()
        model = SimpleMLP(input_size=6, hidden_size=10)
        replayPlugin = ReplayPlugin(mem_size=5)
        cl_strategy = Naive(
            model,
            SGD(model.parameters(), lr=0.001, momentum=0.9,
                weight_decay=0.001),
            CrossEntropyLoss(), train_mb_size=16, train_epochs=1,
            eval_mb_size=16,
            plugins=[replayPlugin]
        )
        for step in benchmark.train_stream[:2]:
            cl_strategy.train(step)

    def test_dataload_batch_balancing(self):
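        """ReplayDataLoader with oversample_small_tasks=True should
        yield mini-batches that are balanced across tasks: in a full
        batch, per-task sample counts may differ by at most one."""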
        benchmark = get_fast_benchmark()
        batch_size = 32
        replayPlugin = ReplayPlugin(mem_size=20)
        model = SimpleMLP(input_size=6, hidden_size=10)
        cl_strategy = Naive(
            model,
            SGD(model.parameters(), lr=0.001, momentum=0.9,
                weight_decay=0.001),
            CrossEntropyLoss(), train_mb_size=batch_size, train_epochs=1,
            eval_mb_size=100, plugins=[replayPlugin]
        )
        for step in benchmark.train_stream:
            adapted_dataset = step.dataset
            dataloader = ReplayDataLoader(
                adapted_dataset,
                AvalancheConcatDataset(replayPlugin.ext_mem.values()),
                oversample_small_tasks=True,
                num_workers=0,
                batch_size=batch_size,
                shuffle=True)
            for mini_batch in dataloader:
                mb_task_labels = mini_batch[-1]
                # count how many samples of each task end up in the batch
                lengths = []
                for task_id in adapted_dataset.task_set:
                    len_task = (mb_task_labels == task_id).sum()
                    lengths.append(len_task)
                # only full batches are expected to be balanced; the last,
                # possibly smaller, batch is only checked for size
                if sum(lengths) == batch_size:
                    difference = max(lengths) - min(lengths)
                    self.assertLessEqual(difference, 1)
                self.assertLessEqual(sum(lengths), batch_size)
            cl_strategy.train(step)


if __name__ == '__main__':
    unittest.main()