# dataloader.py
# Example: customizing the training dataloader of an Avalanche strategy.
  1. from torch.nn import CrossEntropyLoss
  2. from torch.optim import SGD
  3. from avalanche.benchmarks import SplitMNIST
  4. from avalanche.benchmarks.utils.data_loader import TaskBalancedDataLoader
  5. from avalanche.models import SimpleMLP
  6. from avalanche.training.strategies import Naive, Cumulative
  7. class MyCumulativeStrategy(Cumulative):
  8. def make_train_dataloader(self, shuffle=True, **kwargs):
  9. # you can override make_train_dataloader to change the
  10. # strategy's dataloader
  11. # remember to iterate over self.adapted_dataset
  12. self.dataloader = TaskBalancedDataLoader(
  13. self.adapted_dataset,
  14. batch_size=self.train_mb_size)
  15. if __name__ == '__main__':
  16. benchmark = SplitMNIST(n_experiences=5)
  17. model = SimpleMLP(input_size=784, hidden_size=10)
  18. opt = SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=0.001)
  19. # we use our custom strategy to change the dataloading policy.
  20. cl_strategy = MyCumulativeStrategy(
  21. model, opt, CrossEntropyLoss(), train_epochs=1,
  22. train_mb_size=512, eval_mb_size=512)
  23. for step in benchmark.train_stream:
  24. cl_strategy.train(step)
  25. cl_strategy.eval(step)
  26. # If you don't use avalanche's strategies you can also use the dataloader
  27. # directly to iterate the data
  28. data = step.dataset
  29. dl = TaskBalancedDataLoader(data)
  30. for x, y, t in dl:
  31. # by default minibatches in Avalanche have the form <x, y, ..., t>
  32. # with arbitrary additional tensors between y and t.
  33. print(x, y, t)
  34. break