creation.py 1.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950
  1. import numpy as np
  2. import test_utils
  3. import unittest
  4. from cvmodelz.classifiers import Classifier
  5. from cvmodelz.classifiers import MeanModelClassifier
  6. from cvmodelz.models import ModelFactory
  class ClassifierCreationTests(unittest.TestCase):
      """Parametrized checks for wrapping a model in a ``Classifier``.

      ``creation`` and ``loss_computation`` take a model ``key`` and their
      names do not start with ``test``, so unittest does not collect them
      directly. They are handed to ``test_utils.add_tests`` at module level.

      Subclasses can override :meth:`new_clf` to run the same checks against
      a different classifier type.
      """

      # Shape of the dummy batch used by ``loss_computation``. The label
      # vector is drawn with the same batch size as the input array.
      BATCH_SIZE = 4
      N_CHANNELS = 3

      def new_clf(self, key, **kwargs):
          """Create the model for ``key`` and wrap it in a ``Classifier``.

          Args:
              key: Model identifier forwarded to ``ModelFactory.new``.
              **kwargs: Forwarded unchanged to the ``Classifier`` constructor.

          Returns:
              The tuple ``(model, classifier)``.
          """
          model = ModelFactory.new(key, pretrained_model=None)
          return model, Classifier(model, **kwargs)

      def creation(self, key):
          """Check that the classifier exposes the exact model it was given."""
          model, clf = self.new_clf(key)
          self.assertIs(clf.model, model)

      def loss_computation(self, key):
          """Check that calling the classifier on a batch gives a 0-d result.

          An all-ones float32 array of shape
          ``(BATCH_SIZE, N_CHANNELS, input_size, input_size)`` and
          ``BATCH_SIZE`` random labels below ``clf.n_classes`` are passed
          to the classifier.
          """
          # Only the classifier is needed; the model is reached via clf.model.
          _, clf = self.new_clf(key)
          in_size = clf.model.meta.input_size

          # Use the classifier's own array module for both inputs.
          xp = clf.xp
          batch_shape = (self.BATCH_SIZE, self.N_CHANNELS, in_size, in_size)
          X = xp.ones(batch_shape, dtype=np.float32)
          y = xp.random.choice(clf.n_classes, size=self.BATCH_SIZE)

          loss = clf(X, y)

          self.assertIsNotNone(loss)
          self.assertEqual(loss.ndim, 0)
          self.assertEqual(loss.shape, ())
  class MeanModelClassifierCreationTests(ClassifierCreationTests):
      """Run the inherited checks against ``MeanModelClassifier``.

      Only the factory method is overridden; ``creation`` and
      ``loss_computation`` come from ``ClassifierCreationTests``.
      """

      def new_clf(self, key, **kwargs):
          """Create the model for ``key`` and wrap it in a ``MeanModelClassifier``.

          Args:
              key: Model identifier forwarded to ``ModelFactory.new``.
              **kwargs: Forwarded unchanged to the ``MeanModelClassifier``
                  constructor.

          Returns:
              The tuple ``(model, classifier)``.
          """
          backbone = ModelFactory.new(key, pretrained_model=None)
          classifier = MeanModelClassifier(backbone, **kwargs)
          return backbone, classifier
  # Hand every parametrized check to test_utils.add_tests together with the
  # models that ModelFactory reports for the "cvmodelz" and "chainercv2"
  # sources (presumably one generated test per model; see test_utils).
  #
  # NOTE(review): the MeanModelClassifierCreationTests attributes below are
  # inherited, so in Python 3 they are the same function objects as the
  # ClassifierCreationTests ones. Confirm that add_tests attaches the
  # generated tests to the intended class.

  # --- classifier creation ---
  test_utils.add_tests(
      ClassifierCreationTests.creation,
      model_list=ModelFactory.get_models(["cvmodelz", "chainercv2"]),
  )
  test_utils.add_tests(
      MeanModelClassifierCreationTests.creation,
      model_list=ModelFactory.get_models(["cvmodelz", "chainercv2"]),
  )

  # --- loss computation ---
  test_utils.add_tests(
      ClassifierCreationTests.loss_computation,
      model_list=ModelFactory.get_models(["cvmodelz", "chainercv2"]),
  )
  test_utils.add_tests(
      MeanModelClassifierCreationTests.loss_computation,
      model_list=ModelFactory.get_models(["cvmodelz", "chainercv2"]),
  )