import typing as T

import numpy as np
from munch import munchify

from cvmodelz.models import ModelFactory


class Classifier(object):
    """Inference wrapper around a ``cvmodelz`` backbone.

    The backbone is created with ``ModelFactory.new`` and its weights are
    loaded with ``load_for_inference`` when the instance is constructed.
    """

    def __init__(self, configuration: T.Dict):
        """Create the backbone and load its weights.

        Args:
            configuration: mapping that must provide the keys ``model_type``,
                ``n_classes`` and ``weights``. It is converted with
                ``munchify`` so the keys can be read as attributes.
        """
        super().__init__()
        config = munchify(configuration)

        model_type = config.model_type
        n_classes = config.n_classes
        weights = config.weights

        self.backbone = ModelFactory.new(model_type)
        self.backbone.load_for_inference(
            weights,
            n_classes=n_classes,
            # NOTE(review): fixed "model/" path; presumably the prefix under
            # which the weights are stored in the weights file. Confirm
            # against the cvmodelz ``load_for_inference`` documentation.
            path="model/",
            strict=True,
        )

    def __call__(self, im: np.ndarray):
        """Prepare ``im`` for classification.

        Args:
            im: a single image (3-D array) or a batch of images (4-D array).

        Raises:
            AssertionError: if ``im`` is neither 3-D nor 4-D. This check is
                an ``assert``, so it is skipped when Python runs with ``-O``.
        """
        # NOTE(review): as visible here, the method stops after adding the
        # batch axis. It never calls ``self.backbone`` and so returns ``None``.
        # If this is the complete method, the prediction step is missing.
        # Confirm.
        assert im.ndim in (3, 4), \
            "Classifier accepts only RGB images (3D input) or a batch of images (4D input)!"

        if im.ndim == 3:
            # Add a leading batch axis so a single image becomes a batch of one.
            im = im[None]