dataset_inspection.py 2.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
"""
This is a simple example of how to use the dataset inspection plugins
(image samples and labels repartition metrics).
"""
  4. from __future__ import absolute_import
  5. from __future__ import division
  6. from __future__ import print_function
  7. import argparse
  8. from datetime import datetime
  9. import torch
  10. import torch.optim.lr_scheduler
  11. from torch.optim import Adam
  12. from torchvision.transforms import (
  13. Compose,
  14. RandomCrop,
  15. ToTensor,
  16. RandomHorizontalFlip,
  17. Normalize,
  18. )
  19. from avalanche.benchmarks import SplitCIFAR10
  20. from avalanche.evaluation.metric_utils import (
  21. repartition_bar_chart_image_creator,
  22. )
  23. from avalanche.evaluation.metrics.labels_repartition import (
  24. labels_repartition_metrics,
  25. )
  26. from avalanche.evaluation.metrics.images_samples import images_samples_metrics
  27. from avalanche.models import SimpleMLP
  28. from avalanche.training.strategies import Naive
  29. from avalanche.training.plugins import ReplayPlugin
  30. from avalanche.evaluation.metrics import accuracy_metrics
  31. from avalanche.logging import TensorboardLogger, InteractiveLogger
  32. from avalanche.training.plugins import EvaluationPlugin
  33. def main(cuda: int):
  34. # --- CONFIG
  35. device = torch.device(
  36. f"cuda:{cuda}" if torch.cuda.is_available() else "cpu"
  37. )
  38. # --- SCENARIO CREATION
  39. scenario = SplitCIFAR10(n_experiences=2, seed=42)
  40. # ---------
  41. # MODEL CREATION
  42. model = SimpleMLP(num_classes=scenario.n_classes, input_size=196608 // 64)
  43. # choose some metrics and evaluation method
  44. eval_plugin = EvaluationPlugin(
  45. accuracy_metrics(stream=True, experience=True),
  46. images_samples_metrics(
  47. on_train=True, on_eval=True, n_cols=10, n_rows=10,
  48. ),
  49. labels_repartition_metrics(
  50. # image_creator=repartition_bar_chart_image_creator,
  51. on_train=True,
  52. on_eval=True,
  53. ),
  54. loggers=[
  55. TensorboardLogger(f"tb_data/{datetime.now()}"),
  56. InteractiveLogger(),
  57. ],
  58. )
  59. # CREATE THE STRATEGY INSTANCE (NAIVE)
  60. cl_strategy = Naive(
  61. model,
  62. Adam(model.parameters()),
  63. train_mb_size=128,
  64. train_epochs=1,
  65. eval_mb_size=128,
  66. device=device,
  67. plugins=[ReplayPlugin(mem_size=1_000)],
  68. evaluator=eval_plugin,
  69. )
  70. # TRAINING LOOP
  71. for i, experience in enumerate(scenario.train_stream, 1):
  72. cl_strategy.train(experience)
  73. cl_strategy.eval(scenario.test_stream[:i])
  74. if __name__ == "__main__":
  75. parser = argparse.ArgumentParser()
  76. parser.add_argument(
  77. "--cuda",
  78. type=int,
  79. default=0,
  80. help="Select zero-indexed cuda device. -1 to use CPU.",
  81. )
  82. args = parser.parse_args()
  83. main(args.cuda)