# creation.py

import unittest

import numpy as np

import test_utils
from cvmodelz.classifiers import Classifier
from cvmodelz.models import ModelFactory
  6. class ClassifierCreationTests(unittest.TestCase):
  7. def new_clf(self, key, **kwargs):
  8. model = ModelFactory.new(key, pretrained_model=None)
  9. return model, Classifier(model, **kwargs)
  10. def creation(self, key):
  11. model, clf = self.new_clf(key)
  12. self.assertIs(clf.model, model)
  13. def loss_computation(self, key):
  14. model, clf = self.new_clf(key)
  15. in_size = clf.model.meta.input_size
  16. X = clf.xp.ones((4, 3, in_size, in_size), dtype=np.float32)
  17. y = clf.xp.random.choice(clf.n_classes, size=4)
  18. loss = clf(X, y)
  19. self.assertIsNotNone(loss)
  20. self.assertEqual(loss.ndim, 0)
  21. self.assertEqual(loss.shape, ())
  22. test_utils.add_tests(ClassifierCreationTests.creation,
  23. model_list=ModelFactory.get_models(["cvmodelz", "chainercv2"]))
  24. test_utils.add_tests(ClassifierCreationTests.loss_computation,
  25. model_list=ModelFactory.get_models(["cvmodelz", "chainercv2"]))