|
@@ -0,0 +1,36 @@
|
|
|
+import numpy as np
|
|
|
+import typing as T
|
|
|
+
|
|
|
+from munch import munchify
|
|
|
+
|
|
|
+from cvmodelz.models import ModelFactory
|
|
|
+
|
|
|
+
|
|
|
+def
|
|
|
+
|
|
|
+class Classifier(object):
|
|
|
+
|
|
|
+ def __init__(self, configuration: T.Dict):
|
|
|
+ super().__init__()
|
|
|
+
|
|
|
+ config = munchify(configuration)
|
|
|
+
|
|
|
+ model_type = config.model_type
|
|
|
+ n_classes = config.n_classes
|
|
|
+ weights = config.weights
|
|
|
+
|
|
|
+ self.backbone = ModelFactory.new(model_type)
|
|
|
+ self.backbone.load_for_inference(weights,
|
|
|
+ n_classes=n_classes,
|
|
|
+ path="model/",
|
|
|
+ strict=True,
|
|
|
+ )
|
|
|
+
|
|
|
+ def __call__(self, im: np.ndarray):
|
|
|
+ assert im.ndim in (3, 4), \
|
|
|
+ "Classifier accepts only RGB images (3D input) or a batch of images (4D input)!"
|
|
|
+
|
|
|
+ if im.ndim == 3:
|
|
|
+ # expand first axis
|
|
|
+ im = im[None]
|
|
|
+
|