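"""Unit tests for Avalanche training plugins.

These tests check that every StrategyPlugin callback is reached during the
training and evaluation loops, and that LRSchedulerPlugin applies and resets
a MultiStepLR schedule across experiences as expected.
"""
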
import sys
import unittest

import torch
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from torch.nn import CrossEntropyLoss
from torch.optim import SGD
from torch.optim.lr_scheduler import MultiStepLR
from torch.utils.data import TensorDataset

from avalanche.benchmarks import nc_benchmark
from avalanche.logging import TextLogger
from avalanche.models import SimpleMLP
from avalanche.training.plugins import EvaluationPlugin
from avalanche.training.plugins import StrategyPlugin, ReplayPlugin, \
    ExperienceBalancedStoragePolicy, ClassBalancedStoragePolicy
from avalanche.training.plugins.lr_scheduling import LRSchedulerPlugin
from avalanche.training.strategies import Naive


class MockPlugin(StrategyPlugin):
    """Sets a flag for each StrategyPlugin callback that gets invoked."""

    def __init__(self):
        super().__init__()
        self.count = 0
        self.activated = [False for _ in range(22)]

    def before_training_exp(self, strategy, **kwargs):
        self.activated[0] = True

    def after_train_dataset_adaptation(self, strategy, **kwargs):
        self.activated[1] = True

    def before_training_epoch(self, strategy, **kwargs):
        self.activated[2] = True

    def before_training_iteration(self, strategy, **kwargs):
        self.activated[3] = True

    def before_forward(self, strategy, **kwargs):
        self.activated[4] = True

    def after_forward(self, strategy, **kwargs):
        self.activated[5] = True

    def before_backward(self, strategy, **kwargs):
        self.activated[6] = True

    def after_backward(self, strategy, **kwargs):
        self.activated[7] = True

    def after_training_iteration(self, strategy, **kwargs):
        self.activated[8] = True

    def before_update(self, strategy, **kwargs):
        self.activated[9] = True

    def after_update(self, strategy, **kwargs):
        self.activated[10] = True

    def after_training_epoch(self, strategy, **kwargs):
        self.activated[11] = True

    def after_training_exp(self, strategy, **kwargs):
        self.activated[12] = True

    def before_eval(self, strategy, **kwargs):
        self.activated[13] = True

    def after_eval_dataset_adaptation(self, strategy, **kwargs):
        self.activated[14] = True

    def before_eval_exp(self, strategy, **kwargs):
        self.activated[15] = True

    def after_eval_exp(self, strategy, **kwargs):
        self.activated[16] = True

    def after_eval(self, strategy, **kwargs):
        self.activated[17] = True

    def before_eval_iteration(self, strategy, **kwargs):
        self.activated[18] = True

    def before_eval_forward(self, strategy, **kwargs):
        self.activated[19] = True

    def after_eval_forward(self, strategy, **kwargs):
        self.activated[20] = True

    def after_eval_iteration(self, strategy, **kwargs):
        self.activated[21] = True


class PluginTests(unittest.TestCase):
    def test_callback_reachability(self):
        # Check that all the callbacks are called during
        # training and test loops.
        model = SimpleMLP(input_size=6, hidden_size=10)
        optimizer = SGD(model.parameters(), lr=1e-3)
        criterion = CrossEntropyLoss()
        benchmark = self.create_benchmark()

        plug = MockPlugin()
        strategy = Naive(model, optimizer, criterion,
                         train_mb_size=100, train_epochs=1, eval_mb_size=100,
                         device='cpu', plugins=[plug])
        strategy.evaluator.loggers = [TextLogger(sys.stdout)]
        strategy.train(benchmark.train_stream[0], num_workers=4)
        strategy.eval([benchmark.test_stream[0]], num_workers=4)
        assert all(plug.activated)

    def create_benchmark(self, task_labels=False):
        n_samples_per_class = 20
        dataset = make_classification(
            n_samples=10 * n_samples_per_class,
            n_classes=10,
            n_features=6, n_informative=6, n_redundant=0)

        X = torch.from_numpy(dataset[0]).float()
        y = torch.from_numpy(dataset[1]).long()

        train_X, test_X, train_y, test_y = train_test_split(
            X, y, train_size=0.6, shuffle=True, stratify=y)

        train_dataset = TensorDataset(train_X, train_y)
        test_dataset = TensorDataset(test_X, test_y)
        # 10 classes split into 5 class-incremental experiences (2 classes each).
        return nc_benchmark(train_dataset, test_dataset, 5,
                            task_labels=task_labels,
                            fixed_class_order=list(range(10)))
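
    # The scheduler tests below exercise the four combinations of the
    # LRSchedulerPlugin reset flags with MultiStepLR(gamma=0.5,
    # milestones=[2, 3]) and base_lr=4. The first experience always sees
    # learning rates [4, 2, 1]; the expected values for the second
    # experience encode the reset semantics:
    #   reset_lr=True,  reset_scheduler=True  -> schedule restarts: [4, 2, 1]
    #   reset_lr=False, reset_scheduler=True  -> decay restarts from the
    #                                            previous lr (1): [1, .5, .25]
    #   reset_lr=True,  reset_scheduler=False -> lr back to 4 but milestones
    #                                            already passed: [4, 4, 4]
    #   reset_lr=False, reset_scheduler=False -> nothing changes: [1, 1, 1]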
    def test_scheduler_plugin(self):
        self._test_scheduler_plugin(
            gamma=1 / 2., milestones=[2, 3], base_lr=4., epochs=3,
            reset_lr=True, reset_scheduler=True,
            expected=[[4., 2., 1.], [4., 2., 1.]])

        self._test_scheduler_plugin(
            gamma=1 / 2., milestones=[2, 3], base_lr=4., epochs=3,
            reset_lr=False, reset_scheduler=True,
            expected=[[4., 2., 1.], [1., .5, .25]])

        self._test_scheduler_plugin(
            gamma=1 / 2., milestones=[2, 3], base_lr=4., epochs=3,
            reset_lr=True, reset_scheduler=False,
            expected=[[4., 2., 1.], [4., 4., 4.]])

        self._test_scheduler_plugin(
            gamma=1 / 2., milestones=[2, 3], base_lr=4., epochs=3,
            reset_lr=False, reset_scheduler=False,
            expected=[[4., 2., 1.], [1., 1., 1.]])
    def _test_scheduler_plugin(self, gamma, milestones, base_lr, epochs,
                               reset_lr, reset_scheduler, expected):
        class TestPlugin(StrategyPlugin):
            def __init__(self, expected_lrs):
                super().__init__()
                self.expected_lrs = expected_lrs

            def after_training_epoch(self, strategy, **kwargs):
                exp_id = strategy.training_exp_counter
                expected_lr = self.expected_lrs[exp_id][strategy.epoch]
                for group in strategy.optimizer.param_groups:
                    assert group['lr'] == expected_lr

        benchmark = self.create_benchmark()
        model = SimpleMLP(input_size=6, hidden_size=10)
        optim = SGD(model.parameters(), lr=base_lr)
        lr_scheduler_plugin = LRSchedulerPlugin(
            MultiStepLR(optim, milestones=milestones, gamma=gamma),
            reset_lr=reset_lr, reset_scheduler=reset_scheduler)

        cl_strategy = Naive(model, optim, CrossEntropyLoss(),
                            train_mb_size=32, train_epochs=epochs,
                            eval_mb_size=100,
                            plugins=[lr_scheduler_plugin,
                                     TestPlugin(expected)])

        cl_strategy.train(benchmark.train_stream[0])
        cl_strategy.train(benchmark.train_stream[1])
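
# The tests can be run directly through the __main__ guard below or with the
# standard unittest runner.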

if __name__ == '__main__':
    unittest.main()