import chainer
import json
import numpy as np
import typing as T

from munch import munchify
from pathlib import Path

from cvmodelz.models import ModelFactory
from chainercv import transforms as tr

class Classifier:
    """Image classifier wrapping a ``cvmodelz`` backbone loaded for inference.

    Predicted class ids are translated through a JSON mapping file (the
    ``"kr"`` entry of each record). Ids missing from the mapping are returned
    as their decimal string.
    """

    def __init__(self, configuration: T.Dict, root: str):
        """Load the backbone weights and the class-id mapping.

        Args:
            configuration: dict providing ``model_type``, ``n_classes``,
                ``weights``, ``input_size`` and ``mapping``. ``weights`` and
                ``mapping`` are paths resolved relative to ``root``.
            root: base directory for the relative paths in ``configuration``.

        Raises:
            AssertionError: if the weights file does not exist (only while
                assertions are enabled, i.e. not under ``python -O``).
        """
        super().__init__()

        config = munchify(configuration)

        model_type = config.model_type
        n_classes = config.n_classes
        weights = Path(root, config.weights)

        assert weights.exists(), \
            f"could not find classifier weights: {weights}"

        self.input_size = config.input_size
        self.backbone = ModelFactory.new(model_type)
        self.backbone.load_for_inference(weights.resolve(),
                                         n_classes=n_classes,
                                         path="model/",
                                         strict=True)

        # JSON text is UTF-8 by spec; don't depend on the platform's locale
        # encoding when the mapping contains non-ASCII labels.
        with open(Path(root, config.mapping), encoding="utf-8") as f:
            mapping = json.load(f)

        # JSON object keys are always strings; convert them to int so they
        # match the integer class ids produced by argmax.
        self._cls_id2ref = {int(key): value["kr"] for key, value in mapping.items()}

    def _transform(self, im: np.ndarray) -> np.ndarray:
        """Prepare one image for the backbone and center-crop it.

        Calls the backbone's ``prepare_func`` with ``keep_ratio=True`` and
        ``swap_channels=False``. The result is then center-cropped to
        ``(input_size, input_size)``. The crop presumably removes the excess
        left by the aspect-preserving resize; TODO confirm against
        ``prepare_func``.
        """
        _prepare = self.backbone.meta.prepare_func
        size = (self.input_size, self.input_size)

        im = _prepare(im, size=size, keep_ratio=True, swap_channels=False)
        return tr.center_crop(im, size)

    def _scores_to_labels(self, scores: np.ndarray) -> T.List[str]:
        """Map a 2D score matrix to one label per row.

        Each row's argmax is the predicted class id. It is looked up in the
        id-to-reference mapping, falling back to the id's decimal string.
        """
        cls_ids = np.argmax(scores, axis=1)
        return [self._cls_id2ref.get(int(cls_id), str(int(cls_id)))
                for cls_id in cls_ids]

    def __call__(self, im: np.ndarray) -> T.Union[str, T.List[str]]:
        """Classify a single image or a batch of images.

        Args:
            im: a single image (3D array, HxWxC) or a batch of images
                (4D array with the batch on the first axis).

        Returns:
            For a 3D image or a batch of exactly one image: that image's
            label. For a batch of several images: a list with one label per
            image, in input order. For an empty batch: ``[]``. Labels are the
            mapping's ``"kr"`` values (presumably strings) or the class id as
            a string when unmapped.
        """
        assert im.ndim in (3, 4), \
            "Classifier accepts only RGB images (3D input) or a batch of images (4D input)!"

        if im.ndim == 3:
            # HxWxC -> 1xHxWxC
            im = im[None]

        # Nothing to classify; avoid feeding an empty array to the backbone.
        if len(im) == 0:
            return []

        batch = [self._transform(_im) for _im in im]
        x = self.backbone.xp.array(batch)
        with chainer.using_config("train", False), chainer.no_backprop_mode():
            pred = self.backbone(x)
        pred.to_cpu()

        labels = self._scores_to_labels(pred.array)
        # A single image keeps returning a scalar label, as before. Only
        # multi-image batches, which used to fail in `int(...)`, get a list.
        return labels[0] if len(labels) == 1 else labels