Explorar o código

fixed class id to label mapping

Dimitri Korsch %!s(int64=3) %!d(string=hai) anos
pai
achega
e834fa0aa2

+ 4 - 1
models/moth_scanner/scanner/__init__.py

@@ -27,11 +27,14 @@ class Scanner(Interface):
 
         detections = self.detector(bw_im)
 
+        labels = {ml.reference: ml for ml in storage.labels()}
+
         for bbox, info in detections:
             if not info.selected:
                 continue
             x0, y0, x1, y1 = bbox
-            label = self.classifier(bbox.crop(im, enlarge=True))
+            cls_id = self.classifier(bbox.crop(im, enlarge=True))
+            label = labels.get(str(cls_id), cls_id)
             file.add_bounding_box(x0, y0, bbox.w, bbox.h, label=label)
 
     def read_image(self, path: str, mode: int = cv2.IMREAD_COLOR) -> np.ndarray:

+ 1 - 2
models/moth_scanner/scanner/classifier.py

@@ -55,5 +55,4 @@ class Classifier(object):
             pred = self.backbone(x)
         pred.to_cpu()
 
-        return str(int(np.argmax(pred.array, axis=1)))
-
+        return int(np.argmax(pred.array, axis=1))