import typing as T
from pathlib import Path

import numpy as np
from chainercv import transforms as tr
from cvmodelz.models import ModelFactory
from munch import munchify


class Classifier(object):
    """Wrap a cvmodelz backbone and expose it as an image -> class-index callable."""

    def __init__(self, configuration: T.Dict, root: str):
        """Create the backbone and load its weights for inference.

        Args:
            configuration: Mapping that provides the entries ``model_type``,
                ``n_classes``, ``weights`` and ``input_size``.
            root: Directory against which ``configuration["weights"]`` is
                joined and resolved.
        """
        super().__init__()
        config = munchify(configuration)

        model_type = config.model_type
        n_classes = config.n_classes
        weights = Path(root, config.weights).resolve()

        # Edge length of the square that every image is cropped to.
        self.input_size = config.input_size

        self.backbone = ModelFactory.new(model_type)
        self.backbone.load_for_inference(
            weights,
            n_classes=n_classes,
            path="model/",
            strict=True,
        )

    def _transform(self, im: np.ndarray):
        """Prepare one image with the backbone's prepare function and center-crop it.

        The image is handed to ``self.backbone.meta.prepare_func`` with
        ``keep_ratio=True`` and ``swap_channels=False``, then center-cropped
        to ``(input_size, input_size)``.
        """
        _prepare = self.backbone.meta.prepare_func
        size = (self.input_size, self.input_size)

        im = _prepare(im, size=size, keep_ratio=True, swap_channels=False)
        # NOTE(review): keep_ratio=True presumably leaves one side larger than
        # `size`, which the crop below trims -- confirm against prepare_func.
        im = tr.center_crop(im, size)
        return im

    def __call__(self, im: np.ndarray) -> T.Union[int, T.List[int]]:
        """Predict the class index for one image or for a batch of images.

        Args:
            im: A single image (3D array) or a batch of images (4D array
                whose first axis enumerates the images).

        Returns:
            A plain ``int`` when exactly one prediction is produced (a 3D
            input, or a 4D batch holding a single image); otherwise a list
            with one ``int`` per image, in batch order.

        Raises:
            AssertionError: If ``im`` is neither 3D nor 4D. This is an
                ``assert``-based check, so it is skipped under ``python -O``.
        """
        assert im.ndim in (3, 4), \
            "Classifier accepts only RGB images (3D input) or a batch of images (4D input)!"

        if im.ndim == 3:
            # expand first axis
            # CxHxW -> 1xCxHxW
            im = im[None]

        batch = [self._transform(_im) for _im in im]
        x = self.backbone.xp.array(batch)

        pred = self.backbone(x)
        pred.to_cpu()

        labels = np.argmax(pred.array, axis=1)
        if labels.size == 1:
            # Same result as before for single predictions, without relying
            # on int() over an array.
            return labels.item()
        # int() on a multi-element array raises TypeError, so larger batches
        # are returned as a list of Python ints instead.
        return labels.tolist()