classifier.py 1.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556
  1. import numpy as np
  2. import typing as T
  3. from munch import munchify
  4. from pathlib import Path
  5. from cvmodelz.models import ModelFactory
  6. from chainercv import transforms as tr
  7. class Classifier(object):
  8. def __init__(self, configuration: T.Dict, root: str):
  9. super().__init__()
  10. config = munchify(configuration)
  11. model_type = config.model_type
  12. n_classes = config.n_classes
  13. weights = Path(root, config.weights).resolve()
  14. self.input_size = config.input_size
  15. self.backbone = ModelFactory.new(model_type)
  16. self.backbone.load_for_inference(weights,
  17. n_classes=n_classes,
  18. path="model/",
  19. strict=True,
  20. )
  21. def _transform(self, im: np.ndarray):
  22. _prepare = self.backbone.meta.prepare_func
  23. size = (self.input_size, self.input_size)
  24. # print(f"{'Orig:': <14s} {im.shape=}")
  25. im = _prepare(im, size=size, keep_ratio=True, swap_channels=False)
  26. # print(f"{'Prepare:': <14s} {im.shape=}")
  27. im = tr.center_crop(im, size)
  28. # print(f"{'CenterCrop:': <14s} {im.shape=}")
  29. return im
  30. def __call__(self, im: np.ndarray) -> int:
  31. assert im.ndim in (3, 4), \
  32. "Classifier accepts only RGB images (3D input) or a batch of images (4D input)!"
  33. if im.ndim == 3:
  34. # expand first axis
  35. # CxHxW -> 1xCxHxW
  36. im = im[None]
  37. im = [self._transform(_im) for _im in im]
  38. x = self.backbone.xp.array(im)
  39. pred = self.backbone(x)
  40. pred.to_cpu()
  41. return int(np.argmax(pred.array, axis=1))