test_plugins.py 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208
  1. import sys
  2. import torch
  3. import unittest
  4. from sklearn.datasets import make_classification
  5. from sklearn.model_selection import train_test_split
  6. from torch.nn import CrossEntropyLoss
  7. from torch.optim import SGD
  8. from torch.optim.lr_scheduler import MultiStepLR
  9. from torch.utils.data import TensorDataset
  10. from avalanche.benchmarks import nc_benchmark
  11. from avalanche.logging import TextLogger
  12. from avalanche.models import SimpleMLP
  13. from avalanche.training.plugins import EvaluationPlugin
  14. from avalanche.training.plugins import StrategyPlugin, ReplayPlugin, \
  15. ExperienceBalancedStoragePolicy, ClassBalancedStoragePolicy
  16. from avalanche.training.plugins.lr_scheduling import LRSchedulerPlugin
  17. from avalanche.training.strategies import Naive
  18. class MockPlugin(StrategyPlugin):
  19. def __init__(self):
  20. super().__init__()
  21. self.count = 0
  22. self.activated = [False for _ in range(22)]
  23. def before_training_exp(self, strategy, **kwargs):
  24. self.activated[0] = True
  25. def after_train_dataset_adaptation(self, strategy, **kwargs):
  26. self.activated[1] = True
  27. def before_training_epoch(self, strategy, **kwargs):
  28. self.activated[2] = True
  29. def before_training_iteration(self, strategy, **kwargs):
  30. self.activated[3] = True
  31. def before_forward(self, strategy, **kwargs):
  32. self.activated[4] = True
  33. def after_forward(self, strategy, **kwargs):
  34. self.activated[5] = True
  35. def before_backward(self, strategy, **kwargs):
  36. self.activated[6] = True
  37. def after_backward(self, strategy, **kwargs):
  38. self.activated[7] = True
  39. def after_training_iteration(self, strategy, **kwargs):
  40. self.activated[8] = True
  41. def before_update(self, strategy, **kwargs):
  42. self.activated[9] = True
  43. def after_update(self, strategy, **kwargs):
  44. self.activated[10] = True
  45. def after_training_epoch(self, strategy, **kwargs):
  46. self.activated[11] = True
  47. def after_training_exp(self, strategy, **kwargs):
  48. self.activated[12] = True
  49. def before_eval(self, strategy, **kwargs):
  50. self.activated[13] = True
  51. def after_eval_dataset_adaptation(self, strategy, **kwargs):
  52. self.activated[14] = True
  53. def before_eval_exp(self, strategy, **kwargs):
  54. self.activated[15] = True
  55. def after_eval_exp(self, strategy, **kwargs):
  56. self.activated[16] = True
  57. def after_eval(self, strategy, **kwargs):
  58. self.activated[17] = True
  59. def before_eval_iteration(self, strategy, **kwargs):
  60. self.activated[18] = True
  61. def before_eval_forward(self, strategy, **kwargs):
  62. self.activated[19] = True
  63. def after_eval_forward(self, strategy, **kwargs):
  64. self.activated[20] = True
  65. def after_eval_iteration(self, strategy, **kwargs):
  66. self.activated[21] = True
  67. class PluginTests(unittest.TestCase):
  68. def test_callback_reachability(self):
  69. # Check that all the callbacks are called during
  70. # training and test loops.
  71. model = SimpleMLP(input_size=6, hidden_size=10)
  72. optimizer = SGD(model.parameters(), lr=1e-3)
  73. criterion = CrossEntropyLoss()
  74. benchmark = self.create_benchmark()
  75. plug = MockPlugin()
  76. strategy = Naive(model, optimizer, criterion,
  77. train_mb_size=100, train_epochs=1, eval_mb_size=100,
  78. device='cpu', plugins=[plug]
  79. )
  80. strategy.evaluator.loggers = [TextLogger(sys.stdout)]
  81. strategy.train(benchmark.train_stream[0], num_workers=4)
  82. strategy.eval([benchmark.test_stream[0]], num_workers=4)
  83. assert all(plug.activated)
  84. def create_benchmark(self, task_labels=False):
  85. n_samples_per_class = 20
  86. dataset = make_classification(
  87. n_samples=10 * n_samples_per_class,
  88. n_classes=10,
  89. n_features=6, n_informative=6, n_redundant=0)
  90. X = torch.from_numpy(dataset[0]).float()
  91. y = torch.from_numpy(dataset[1]).long()
  92. train_X, test_X, train_y, test_y = train_test_split(
  93. X, y, train_size=0.6, shuffle=True, stratify=y)
  94. train_dataset = TensorDataset(train_X, train_y)
  95. test_dataset = TensorDataset(test_X, test_y)
  96. return nc_benchmark(train_dataset, test_dataset, 5,
  97. task_labels=task_labels,
  98. fixed_class_order=list(range(10)))
  99. def test_scheduler_plugin(self):
  100. self._test_scheduler_plugin(gamma=1 / 2.,
  101. milestones=[2, 3],
  102. base_lr=4.,
  103. epochs=3,
  104. reset_lr=True,
  105. reset_scheduler=True,
  106. expected=[[4., 2., 1.],
  107. [4., 2., 1.]],
  108. )
  109. self._test_scheduler_plugin(gamma=1 / 2.,
  110. milestones=[2, 3],
  111. base_lr=4.,
  112. epochs=3,
  113. reset_lr=False,
  114. reset_scheduler=True,
  115. expected=[[4., 2., 1.],
  116. [1., .5, .25]],
  117. )
  118. self._test_scheduler_plugin(gamma=1 / 2.,
  119. milestones=[2, 3],
  120. base_lr=4.,
  121. epochs=3,
  122. reset_lr=True,
  123. reset_scheduler=False,
  124. expected=[[4., 2., 1.],
  125. [4., 4., 4.]],
  126. )
  127. self._test_scheduler_plugin(gamma=1 / 2.,
  128. milestones=[2, 3],
  129. base_lr=4.,
  130. epochs=3,
  131. reset_lr=False,
  132. reset_scheduler=False,
  133. expected=[[4., 2., 1.],
  134. [1., 1., 1.]],
  135. )
  136. def _test_scheduler_plugin(self, gamma, milestones, base_lr, epochs,
  137. reset_lr, reset_scheduler, expected):
  138. class TestPlugin(StrategyPlugin):
  139. def __init__(self, expected_lrs):
  140. super().__init__()
  141. self.expected_lrs = expected_lrs
  142. def after_training_epoch(self, strategy, **kwargs):
  143. exp_id = strategy.training_exp_counter
  144. expected_lr = self.expected_lrs[exp_id][strategy.epoch]
  145. for group in strategy.optimizer.param_groups:
  146. assert group['lr'] == expected_lr
  147. benchmark = self.create_benchmark()
  148. model = SimpleMLP(input_size=6, hidden_size=10)
  149. optim = SGD(model.parameters(), lr=base_lr)
  150. lrSchedulerPlugin = LRSchedulerPlugin(
  151. MultiStepLR(optim, milestones=milestones, gamma=gamma),
  152. reset_lr=reset_lr, reset_scheduler=reset_scheduler)
  153. cl_strategy = Naive(model, optim, CrossEntropyLoss(), train_mb_size=32,
  154. train_epochs=epochs, eval_mb_size=100,
  155. plugins=[lrSchedulerPlugin, TestPlugin(expected)])
  156. cl_strategy.train(benchmark.train_stream[0])
  157. cl_strategy.train(benchmark.train_stream[1])
  158. if __name__ == '__main__':
  159. unittest.main()