6
0

classifier.py 2.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364
  1. import chainer
  2. import json
  3. import numpy as np
  4. import typing as T
  5. from munch import munchify
  6. from pathlib import Path
  7. from cvmodelz.models import ModelFactory
  8. from chainercv import transforms as tr
  9. class Classifier(object):
  10. def __init__(self, configuration: T.Dict, root: str):
  11. super().__init__()
  12. config = munchify(configuration)
  13. model_type = config.model_type
  14. n_classes = config.n_classes
  15. weights = Path(root, config.weights)
  16. assert weights.exists(), \
  17. f"could not find classifier weights: {weights}"
  18. self.input_size = config.input_size
  19. self.backbone = ModelFactory.new(model_type)
  20. self.backbone.load_for_inference(weights.resolve(),
  21. n_classes=n_classes,
  22. path="model/",
  23. strict=True)
  24. with open(Path(root, config.mapping)) as f:
  25. mapping = json.load(f)
  26. self._cls_id2ref = {int(key): value["kr"] for key, value in mapping.items()}
  27. def _transform(self, im: np.ndarray):
  28. _prepare = self.backbone.meta.prepare_func
  29. size = (self.input_size, self.input_size)
  30. # print(f"{'Orig:': <14s} {im.shape=}")
  31. im = _prepare(im, size=size, keep_ratio=True, swap_channels=False)
  32. # print(f"{'Prepare:': <14s} {im.shape=}")
  33. im = tr.center_crop(im, size)
  34. # print(f"{'CenterCrop:': <14s} {im.shape=}")
  35. return im
  36. def __call__(self, im: np.ndarray) -> int:
  37. assert im.ndim in (3, 4), \
  38. "Classifier accepts only RGB images (3D input) or a batch of images (4D input)!"
  39. if im.ndim == 3:
  40. # expand first axis
  41. # HxWxC -> 1xHxWxC
  42. im = im[None]
  43. im = [self._transform(_im) for _im in im]
  44. x = self.backbone.xp.array(im)
  45. with chainer.using_config("train", False), chainer.no_backprop_mode():
  46. pred = self.backbone(x)
  47. pred.to_cpu()
  48. cls_id = int(np.argmax(pred.array, axis=1))
  49. return self._cls_id2ref.get(cls_id, str(cls_id))