################################################################################
# Copyright (c) 2021 ContinualAI.                                              #
# Copyrights licensed under the MIT License.                                   #
# See the accompanying LICENSE file for terms.                                 #
#                                                                              #
# Date: 1-06-2020                                                              #
# Author(s): Andrea Cossu                                                      #
# E-mail: contact@continualai.org                                              #
# Website: avalanche.continualai.org                                           #
################################################################################
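"""Tests for the stream completeness checks performed by the
EvaluationPlugin: with strict_checks=True, evaluating on an incomplete
test stream raises a ValueError, while with strict_checks=False it only
emits a UserWarning."""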
import unittest

from torch.optim import SGD
from torch.nn import CrossEntropyLoss

from avalanche.models import SimpleMLP
from avalanche.training.plugins import EvaluationPlugin
from avalanche.training.strategies import Naive
from avalanche.evaluation.metrics import accuracy_metrics
from tests.unit_tests_utils import get_fast_benchmark


class TestStreamCompleteness(unittest.TestCase):
    def setUp(self) -> None:
        # unittest calls setUp on the instance before each test, so every
        # test starts from a fresh model, optimizer, and benchmark.
        self.model = SimpleMLP(input_size=6, hidden_size=10)
        self.optimizer = SGD(self.model.parameters(), lr=1e-3)
        self.criterion = CrossEntropyLoss()
        self.benchmark = get_fast_benchmark()

    def test_raise_error(self):
        # With strict_checks=True the plugin requires every evaluation to
        # cover the complete test stream.
        eval_plugin = EvaluationPlugin(accuracy_metrics(stream=True),
                                       loggers=None,
                                       benchmark=self.benchmark,
                                       strict_checks=True)
        strategy = Naive(self.model, self.optimizer, self.criterion,
                         train_epochs=2, eval_every=-1,
                         evaluator=eval_plugin)
        for exp in self.benchmark.train_stream:
            strategy.train(exp)
            # Evaluating on the full test stream satisfies the check.
            strategy.eval(self.benchmark.test_stream)
        # Evaluating on a partial stream must fail under strict checks.
        with self.assertRaises(ValueError):
            strategy.eval(self.benchmark.test_stream[:2])

    def test_raise_warning(self):
        # With strict_checks=False an incomplete test stream only triggers
        # a warning instead of an error.
        eval_plugin = EvaluationPlugin(accuracy_metrics(stream=True),
                                       loggers=None,
                                       benchmark=self.benchmark,
                                       strict_checks=False)
        strategy = Naive(self.model, self.optimizer, self.criterion,
                         train_epochs=2, eval_every=-1,
                         evaluator=eval_plugin)
        for exp in self.benchmark.train_stream:
            strategy.train(exp)
            strategy.eval(self.benchmark.test_stream)
        with self.assertWarns(UserWarning):
            strategy.eval(self.benchmark.test_stream[:2])

    def test_no_errors(self):
        # Strict checks pass as long as evaluation always covers the
        # complete test stream, including the periodic evaluation enabled
        # by eval_every=0 (evaluate after every training epoch).
        eval_plugin = EvaluationPlugin(accuracy_metrics(stream=True),
                                       loggers=None,
                                       benchmark=self.benchmark,
                                       strict_checks=True)
        strategy = Naive(self.model, self.optimizer, self.criterion,
                         train_epochs=2, eval_every=0,
                         evaluator=eval_plugin)
        for exp in self.benchmark.train_stream:
            strategy.train(exp,
                           eval_streams=[self.benchmark.test_stream])
        strategy.eval(self.benchmark.test_stream)
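

# Standard unittest entry point so the module can also be run directly
# with Python, in addition to being collected by a test runner.
if __name__ == "__main__":
    unittest.main()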