classifier.py 941 B

123456789101112131415161718192021222324252627282930313233343536
  1. import numpy as np
  2. import typing as T
  3. from munch import munchify
  4. from cvmodelz.models import ModelFactory
  5. def
  6. class Classifier(object):
  def __init__(self, configuration: T.Dict):
      """Create the classifier backbone and load its weights for inference.

      Args:
          configuration: mapping that must provide ``model_type``,
              ``n_classes`` and ``weights``. It is wrapped with
              ``munchify`` so these entries can be read as attributes.
      """
      super().__init__()
      cfg = munchify(configuration)

      # Read every required setting up front, in the same order as before,
      # so a missing key fails before any model is constructed.
      model_type, n_classes, weights = cfg.model_type, cfg.n_classes, cfg.weights

      # Instantiate the architecture named by ``model_type``.
      self.backbone = ModelFactory.new(model_type)

      # NOTE(review): "model/" is presumably the key prefix under which the
      # parameters are stored in the weights file, and strict=True presumably
      # makes the loader reject missing/unexpected parameters — confirm
      # against cvmodelz's load_for_inference.
      self.backbone.load_for_inference(
          weights,
          n_classes=n_classes,
          path="model/",
          strict=True,
      )
  19. def __call__(self, im: np.ndarray):
  20. assert im.ndim in (3, 4), \
  21. "Classifier accepts only RGB images (3D input) or a batch of images (4D input)!"
  22. if im.ndim == 3:
  23. # expand first axis
  24. im = im[None]