123456789101112131415161718192021222324252627282930313233343536 |
- import numpy as np
- import typing as T
- from munch import munchify
- from cvmodelz.models import ModelFactory
class Classifier:
    """Inference wrapper around a ``cvmodelz`` backbone.

    The backbone is created and its weights are loaded once, in the
    constructor. Calling the instance validates an input image array and
    brings it into batch form.
    """

    def __init__(self, configuration: T.Dict) -> None:
        """Build the backbone described by *configuration* and load its weights.

        Args:
            configuration: mapping from which exactly the keys ``model_type``,
                ``n_classes`` and ``weights`` are read; any other keys are
                ignored. It is wrapped with ``munchify`` so the keys can be
                read as attributes (a missing key presumably surfaces as an
                ``AttributeError`` from munch).
        """
        super().__init__()
        config = munchify(configuration)

        model_type = config.model_type
        n_classes = config.n_classes
        weights = config.weights

        self.backbone = ModelFactory.new(model_type)
        # NOTE(review): `path="model/"` looks like the prefix under which the
        # model parameters are stored inside the weights file -- confirm
        # against the cvmodelz `load_for_inference` documentation.
        self.backbone.load_for_inference(
            weights,
            n_classes=n_classes,
            path="model/",
            strict=True,
        )

    def __call__(self, im: np.ndarray):
        """Validate *im* and bring it to batch form.

        Args:
            im: a single image (3-D array) or a batch of images (4-D array).

        Raises:
            ValueError: if ``im.ndim`` is neither 3 nor 4.
        """
        # Explicit check instead of `assert`: asserts are stripped under
        # `python -O`, which would let malformed input through silently.
        if im.ndim not in (3, 4):
            raise ValueError(
                "Classifier accepts only RGB images (3D input) "
                "or a batch of images (4D input)!"
            )

        if im.ndim == 3:
            # Prepend a batch axis of size 1 so a single image is handled
            # like a batch of one.
            im = im[None]

        # TODO(review): the method ends here -- `im` is batched but never fed
        # to `self.backbone`, so the call currently returns None. The forward
        # pass is not part of this code; complete it before relying on the
        # result.
|