12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697 |
- """
- This is a simple example of how to use the Dataset inspection plugins.
- """
- from __future__ import absolute_import
- from __future__ import division
- from __future__ import print_function
- import argparse
- from datetime import datetime
- import torch
- import torch.optim.lr_scheduler
- from torch.optim import Adam
- from torchvision.transforms import (
- Compose,
- RandomCrop,
- ToTensor,
- RandomHorizontalFlip,
- Normalize,
- )
- from avalanche.benchmarks import SplitCIFAR10
- from avalanche.evaluation.metric_utils import (
- repartition_bar_chart_image_creator,
- )
- from avalanche.evaluation.metrics.labels_repartition import (
- labels_repartition_metrics,
- )
- from avalanche.evaluation.metrics.images_samples import images_samples_metrics
- from avalanche.models import SimpleMLP
- from avalanche.training.strategies import Naive
- from avalanche.training.plugins import ReplayPlugin
- from avalanche.evaluation.metrics import accuracy_metrics
- from avalanche.logging import TensorboardLogger, InteractiveLogger
- from avalanche.training.plugins import EvaluationPlugin
def main(cuda: int):
    """Run the dataset-inspection example on SplitCIFAR10.

    Builds a two-experience SplitCIFAR10 scenario, trains a ``SimpleMLP``
    with the ``Naive`` strategy plus a ``ReplayPlugin``, and logs accuracy,
    image-sample and label-repartition metrics to TensorBoard and to the
    interactive logger. After each training experience the model is
    evaluated on the test experiences seen so far.

    :param cuda: Zero-indexed CUDA device to use. A negative value
        (e.g. ``-1``) selects the CPU; the CPU is also used whenever CUDA
        is not available.
    """
    # --- CONFIG
    # Honour the CLI contract ("-1 to use CPU"): a negative index must select
    # the CPU even on a CUDA machine, rather than being formatted into the
    # device string "cuda:-1".
    use_cuda = cuda >= 0 and torch.cuda.is_available()
    device = torch.device(f"cuda:{cuda}" if use_cuda else "cpu")

    # --- SCENARIO CREATION
    scenario = SplitCIFAR10(n_experiences=2, seed=42)
    # ---------

    # MODEL CREATION
    # 3 * 32 * 32 == 3072 (same value as the former `196608 // 64`):
    # the size of a flattened 32x32 RGB image.
    model = SimpleMLP(num_classes=scenario.n_classes, input_size=3 * 32 * 32)

    # choose some metrics and evaluation method
    eval_plugin = EvaluationPlugin(
        accuracy_metrics(stream=True, experience=True),
        images_samples_metrics(
            on_train=True,
            on_eval=True,
            n_cols=10,
            n_rows=10,
        ),
        labels_repartition_metrics(
            # image_creator=repartition_bar_chart_image_creator,
            on_train=True,
            on_eval=True,
        ),
        loggers=[
            # NOTE(review): str(datetime.now()) contains spaces and colons;
            # colons are not valid in Windows paths. Consider strftime if
            # this example must be portable.
            TensorboardLogger(f"tb_data/{datetime.now()}"),
            InteractiveLogger(),
        ],
    )

    # CREATE THE STRATEGY INSTANCE (NAIVE)
    cl_strategy = Naive(
        model,
        Adam(model.parameters()),
        train_mb_size=128,
        train_epochs=1,
        eval_mb_size=128,
        device=device,
        plugins=[ReplayPlugin(mem_size=1_000)],
        evaluator=eval_plugin,
    )

    # TRAINING LOOP
    # Counting from 1 makes `test_stream[:i]` cover exactly the first `i`
    # test experiences, i.e. those matching the experiences trained so far.
    for i, experience in enumerate(scenario.train_stream, 1):
        cl_strategy.train(experience)
        cl_strategy.eval(scenario.test_stream[:i])
if __name__ == "__main__":
    # Script entry point: read the device index from the command line and
    # hand it to main().
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--cuda",
        default=0,
        type=int,
        help="Select zero-indexed cuda device. -1 to use CPU.",
    )
    main(cli.parse_args().cuda)
|