# test_stream_completeness.py
################################################################################
# Copyright (c) 2021 ContinualAI.                                              #
# Copyrights licensed under the MIT License.                                   #
# See the accompanying LICENSE file for terms.                                 #
#                                                                              #
# Date: 1-06-2020                                                              #
# Author(s): Andrea Cossu                                                      #
# E-mail: contact@continualai.org                                              #
# Website: avalanche.continualai.org                                           #
################################################################################
import unittest

from torch.optim import SGD
from torch.nn import CrossEntropyLoss

from avalanche.models import SimpleMLP
from avalanche.training.plugins import EvaluationPlugin
from avalanche.training.strategies import Naive
from avalanche.evaluation.metrics import accuracy_metrics
from tests.unit_tests_utils import get_fast_benchmark


  19. class TestStreamCompleteness(unittest.TestCase):
  20. @classmethod
  21. def setUp(cls) -> None:
  22. cls.model = SimpleMLP(input_size=6, hidden_size=10)
  23. cls.optimizer = SGD(cls.model.parameters(), lr=1e-3)
  24. cls.criterion = CrossEntropyLoss()
  25. cls.benchmark = get_fast_benchmark()
  26. def test_raise_error(self):
  27. eval_plugin = EvaluationPlugin(accuracy_metrics(stream=True),
  28. loggers=None,
  29. benchmark=self.benchmark,
  30. strict_checks=True)
  31. strategy = Naive(self.model, self.optimizer, self.criterion,
  32. train_epochs=2, eval_every=-1,
  33. evaluator=eval_plugin)
  34. for exp in self.benchmark.train_stream:
  35. strategy.train(exp)
  36. strategy.eval(self.benchmark.test_stream)
  37. with self.assertRaises(ValueError):
  38. strategy.eval(self.benchmark.test_stream[:2])
  39. def test_raise_warning(self):
  40. eval_plugin = EvaluationPlugin(accuracy_metrics(stream=True),
  41. loggers=None,
  42. benchmark=self.benchmark,
  43. strict_checks=False)
  44. strategy = Naive(self.model, self.optimizer, self.criterion,
  45. train_epochs=2, eval_every=-1,
  46. evaluator=eval_plugin)
  47. for exp in self.benchmark.train_stream:
  48. strategy.train(exp)
  49. strategy.eval(self.benchmark.test_stream)
  50. with self.assertWarns(UserWarning):
  51. strategy.eval(self.benchmark.test_stream[:2])
  52. def test_no_errors(self):
  53. eval_plugin = EvaluationPlugin(accuracy_metrics(stream=True),
  54. loggers=None,
  55. benchmark=self.benchmark,
  56. strict_checks=True)
  57. strategy = Naive(self.model, self.optimizer, self.criterion,
  58. train_epochs=2, eval_every=0,
  59. evaluator=eval_plugin)
  60. for exp in self.benchmark.train_stream:
  61. strategy.train(exp, eval_streams=[
  62. self.benchmark.test_stream])
  63. strategy.eval(self.benchmark.test_stream)