mean_scores.py 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172
  1. from __future__ import absolute_import
  2. from __future__ import division
  3. from __future__ import print_function
  4. import argparse
  5. from datetime import datetime
  6. import torch
  7. import torch.optim.lr_scheduler
  8. from torch.optim import Adam
  9. from avalanche.benchmarks import SplitMNIST
  10. from avalanche.evaluation.metrics.mean_scores import mean_scores_metrics
  11. from avalanche.models import SimpleMLP
  12. from avalanche.training.strategies import Naive
  13. from avalanche.training.plugins import ReplayPlugin
  14. from avalanche.evaluation.metrics import accuracy_metrics
  15. from avalanche.logging import TensorboardLogger, InteractiveLogger
  16. from avalanche.training.plugins import EvaluationPlugin
  17. def main(cuda: int):
  18. # --- CONFIG
  19. device = torch.device(
  20. f"cuda:{cuda}" if torch.cuda.is_available() else "cpu"
  21. )
  22. # --- SCENARIO CREATION
  23. scenario = SplitMNIST(n_experiences=5, seed=42)
  24. # ---------
  25. # MODEL CREATION
  26. model = SimpleMLP(num_classes=scenario.n_classes)
  27. # choose some metrics and evaluation method
  28. eval_plugin = EvaluationPlugin(
  29. accuracy_metrics(stream=True, experience=True),
  30. mean_scores_metrics(on_train=True, on_eval=True),
  31. loggers=[
  32. TensorboardLogger(f"tb_data/{datetime.now()}"),
  33. InteractiveLogger(),
  34. ],
  35. )
  36. # CREATE THE STRATEGY INSTANCE (NAIVE)
  37. cl_strategy = Naive(
  38. model,
  39. Adam(model.parameters()),
  40. train_mb_size=128,
  41. train_epochs=2,
  42. eval_mb_size=128,
  43. device=device,
  44. plugins=[ReplayPlugin(mem_size=100)],
  45. evaluator=eval_plugin,
  46. )
  47. # TRAINING LOOP
  48. for i, experience in enumerate(scenario.train_stream, 1):
  49. cl_strategy.train(experience)
  50. cl_strategy.eval(scenario.test_stream[:i])
  51. if __name__ == "__main__":
  52. parser = argparse.ArgumentParser()
  53. parser.add_argument(
  54. "--cuda",
  55. type=int,
  56. default=0,
  57. help="Select zero-indexed cuda device. -1 to use CPU.",
  58. )
  59. args = parser.parse_args()
  60. main(args.cuda)