task_incremental.py 2.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465
  1. ################################################################################
  2. # Copyright (c) 2021 ContinualAI. #
  3. # Copyrights licensed under the MIT License. #
  4. # See the accompanying LICENSE file for terms. #
  5. # #
  6. # Date: 01-12-2020 #
  7. # Author(s): Andrea Cossu #
  8. # E-mail: contact@continualai.org #
  9. # Website: avalanche.continualai.org #
  10. ################################################################################
  11. """
This example trains on Split CIFAR10 with the Naive strategy.
In this example, each experience has a different task label.
  14. The task label, although available, is not used at test time.
  15. """
  16. from __future__ import absolute_import
  17. from __future__ import division
  18. from __future__ import print_function
  19. import argparse
  20. import torch
  21. from torch.nn import CrossEntropyLoss
  22. from torch.optim import Adam
  23. from avalanche.benchmarks.classic import SplitCIFAR10
  24. from avalanche.models import SimpleMLP
  25. from avalanche.training.strategies import Naive
  26. def main(args):
  27. # Config
  28. device = torch.device(f"cuda:{args.cuda}"
  29. if torch.cuda.is_available() and
  30. args.cuda >= 0 else "cpu")
  31. # model
  32. model = SimpleMLP(input_size=32*32*3, num_classes=10)
  33. # CL Benchmark Creation
  34. scenario = SplitCIFAR10(n_experiences=5, return_task_id=True)
  35. train_stream = scenario.train_stream
  36. test_stream = scenario.test_stream
  37. # Prepare for training & testing
  38. optimizer = Adam(model.parameters(), lr=0.01)
  39. criterion = CrossEntropyLoss()
  40. # Choose a CL strategy
  41. strategy = Naive(
  42. model=model, optimizer=optimizer, criterion=criterion,
  43. train_mb_size=128, train_epochs=3, eval_mb_size=128, device=device)
  44. # train and test loop
  45. for train_task in train_stream:
  46. strategy.train(train_task, num_workers=0)
  47. strategy.eval(test_stream)
  48. if __name__ == '__main__':
  49. parser = argparse.ArgumentParser()
  50. parser.add_argument('--cuda', type=int, default=0,
  51. help='Select zero-indexed cuda device. -1 to use CPU.')
  52. args = parser.parse_args()
  53. main(args)