classifier.py

import chainer
import json
import numpy as np
import typing as T

from munch import munchify
from pathlib import Path

from cvmodelz.models import ModelFactory
from chainercv import transforms as tr


class Classifier(object):

    def __init__(self, configuration: T.Dict, root: str):
        super().__init__()
        config = munchify(configuration)

        model_type = config.model_type
        n_classes = config.n_classes
        weights = Path(root, config.weights).resolve()
        self.input_size = config.input_size

        # build the backbone and load the trained weights for inference
        self.backbone = ModelFactory.new(model_type)
        self.backbone.load_for_inference(weights,
            n_classes=n_classes,
            path="model/",
            strict=True,
        )

        # mapping of class IDs to human-readable references
        with open(Path(root, config.mapping)) as f:
            mapping = json.load(f)

        self._cls_id2ref = {int(key): value["kr"] for key, value in mapping.items()}

    def _transform(self, im: np.ndarray):
        # resize with the backbone's prepare function (keeping the aspect
        # ratio), then center-crop to the final input size
        _prepare = self.backbone.meta.prepare_func
        size = (self.input_size, self.input_size)

        # print(f"{'Orig:': <14s} {im.shape=}")
        im = _prepare(im, size=size, keep_ratio=True, swap_channels=False)
        # print(f"{'Prepare:': <14s} {im.shape=}")
        im = tr.center_crop(im, size)
        # print(f"{'CenterCrop:': <14s} {im.shape=}")
        return im

    def __call__(self, im: np.ndarray) -> str:
        assert im.ndim in (3, 4), \
            "Classifier accepts only RGB images (3D input) or a batch of images (4D input)!"

        if im.ndim == 3:
            # expand first axis: HxWxC -> 1xHxWxC
            im = im[None]

        im = [self._transform(_im) for _im in im]
        x = self.backbone.xp.array(im)

        # forward pass in test mode, without building a computation graph
        with chainer.using_config("train", False), chainer.no_backprop_mode():
            pred = self.backbone(x)

        pred.to_cpu()
        cls_id = int(np.argmax(pred.array, axis=1))
        # return the mapped reference, falling back to the raw class ID
        return self._cls_id2ref.get(cls_id, str(cls_id))
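

# ---------------------------------------------------------------------------
# Minimal usage sketch. The configuration keys mirror exactly what
# Classifier.__init__ reads above; the concrete values (model identifier,
# file names, class count, input size) are placeholder assumptions and need
# to be replaced with the values used by the actual project.
if __name__ == "__main__":
    configuration = {
        "model_type": "cvmodelz.InceptionV3",  # assumed name understood by ModelFactory
        "n_classes": 200,                      # assumed number of output classes
        "weights": "model.npz",                # assumed weight file, relative to `root`
        "input_size": 299,                     # assumed network input resolution
        "mapping": "mapping.json",             # assumed JSON with {"<cls_id>": {"kr": ...}} entries
    }

    clf = Classifier(configuration, root=".")

    # a dummy RGB image (HxWxC, uint8) stands in for a real photograph here
    image = np.random.randint(0, 256, size=(480, 640, 3), dtype=np.uint8)
    print(clf(image))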