|
@@ -1,4 +1,5 @@
|
|
|
import chainer
|
|
|
+import json
|
|
|
import numpy as np
|
|
|
import typing as T
|
|
|
|
|
@@ -27,6 +28,10 @@ class Classifier(object):
|
|
|
strict=True,
|
|
|
)
|
|
|
|
|
|
+ with open(Path(root, config.mapping)) as f:
|
|
|
+ mapping = json.load(f)
|
|
|
+ self._cls_id2ref = {int(key): value["kr"] for key, value in mapping.items()}
|
|
|
+
|
|
|
def _transform(self, im: np.ndarray):
|
|
|
_prepare = self.backbone.meta.prepare_func
|
|
|
size = (self.input_size, self.input_size)
|
|
@@ -55,4 +60,5 @@ class Classifier(object):
|
|
|
pred = self.backbone(x)
|
|
|
pred.to_cpu()
|
|
|
|
|
|
- return int(np.argmax(pred.array, axis=1))
|
|
|
+ cls_id = int(np.argmax(pred.array, axis=1))
|
|
|
+ return self._cls_id2ref.get(cls_id, str(cls_id))
|