# classifier.py
  1. import chainer
  2. import numpy as np
  3. import typing as T
  4. from munch import munchify
  5. from pathlib import Path
  6. from cvmodelz.models import ModelFactory
  7. from chainercv import transforms as tr
  8. class Classifier(object):
  9. def __init__(self, configuration: T.Dict, root: str):
  10. super().__init__()
  11. config = munchify(configuration)
  12. model_type = config.model_type
  13. n_classes = config.n_classes
  14. weights = Path(root, config.weights).resolve()
  15. self.input_size = config.input_size
  16. self.backbone = ModelFactory.new(model_type)
  17. self.backbone.load_for_inference(weights,
  18. n_classes=n_classes,
  19. path="model/",
  20. strict=True,
  21. )
  22. def _transform(self, im: np.ndarray):
  23. _prepare = self.backbone.meta.prepare_func
  24. size = (self.input_size, self.input_size)
  25. # print(f"{'Orig:': <14s} {im.shape=}")
  26. im = _prepare(im, size=size, keep_ratio=True, swap_channels=False)
  27. # print(f"{'Prepare:': <14s} {im.shape=}")
  28. im = tr.center_crop(im, size)
  29. # print(f"{'CenterCrop:': <14s} {im.shape=}")
  30. return im
  31. def __call__(self, im: np.ndarray) -> int:
  32. assert im.ndim in (3, 4), \
  33. "Classifier accepts only RGB images (3D input) or a batch of images (4D input)!"
  34. if im.ndim == 3:
  35. # expand first axis
  36. # HxWxC -> 1xHxWxC
  37. im = im[None]
  38. im = [self._transform(_im) for _im in im]
  39. x = self.backbone.xp.array(im)
  40. with chainer.using_config("train", False), chainer.no_backprop_mode():
  41. pred = self.backbone(x)
  42. pred.to_cpu()
  43. return int(np.argmax(pred.array, axis=1))