```python
import numpy as np
import typing as T

from munch import munchify
from pathlib import Path

from cvmodelz.models import ModelFactory
from chainercv import transforms as tr


class Classifier(object):

    def __init__(self, configuration: T.Dict, root: str):
        super().__init__()
        config = munchify(configuration)

        model_type = config.model_type
        n_classes = config.n_classes
        weights = Path(root, config.weights).resolve()
        self.input_size = config.input_size

        # build the backbone via the cvmodelz factory and load the
        # trained weights for inference with n_classes outputs
        self.backbone = ModelFactory.new(model_type)
        self.backbone.load_for_inference(weights,
            n_classes=n_classes,
            path="model/",
            strict=True,
        )

    def _transform(self, im: np.ndarray):
        # model-specific preprocessing via the backbone's own prepare
        # function, followed by a center crop to the configured input size
        _prepare = self.backbone.meta.prepare_func
        size = (self.input_size, self.input_size)

        # print(f"{'Orig:': <14s} {im.shape=}")
        im = _prepare(im, size=size, keep_ratio=True, swap_channels=False)
        # print(f"{'Prepare:': <14s} {im.shape=}")
        im = tr.center_crop(im, size)
        # print(f"{'CenterCrop:': <14s} {im.shape=}")
        return im

    def __call__(self, im: np.ndarray) -> int:
        assert im.ndim in (3, 4), \
            "Classifier accepts only RGB images (3D input) or a batch of images (4D input)!"

        if im.ndim == 3:
            # expand first axis
            # CxHxW -> 1xCxHxW
            im = im[None]

        # preprocess every image, stack the results into a batch on the
        # backbone's device (NumPy or CuPy), and run the forward pass
        im = [self._transform(_im) for _im in im]
        x = self.backbone.xp.array(im)
        pred = self.backbone(x)
        pred.to_cpu()

        return int(np.argmax(pred.array, axis=1))
```
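
For reference, a minimal usage sketch, assuming the `Classifier` class above is defined in the same module. Only the configuration keys (`model_type`, `n_classes`, `weights`, `input_size`) come from the class; the concrete model identifier, class count, and file paths below are placeholders and would need to match the actual project.

```python
from chainercv.utils import read_image

# Hypothetical configuration: the keys mirror what Classifier.__init__ reads,
# the values (model identifier, weights file, sizes) are placeholders.
config = dict(
    model_type="cvmodelz.InceptionV3",
    n_classes=200,
    weights="models/clf_final.npz",   # resolved relative to `root`
    input_size=299,
)

clf = Classifier(config, root=".")

# read_image returns a float32 CHW RGB array, i.e. the layout the
# chainercv transforms inside _transform operate on
im = read_image("example.jpg")
print("predicted class ID:", clf(im))
```

Because `__call__` allocates the input batch with `self.backbone.xp`, moving the backbone to a GPU (e.g. `clf.backbone.to_gpu()`) should make the same call operate on CuPy arrays without further changes.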