Browse Source

Inference for Custom Bounding Boxes

A user can now run classification for a custom bounding box that does not yet have a label. That is, the user only has to draw the bounding box and does not necessarily have to select the label themselves.
j-bl 3 years ago
parent
commit
1c9c605000
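For illustration, the classification is triggered through the new predict_bounding_box endpoint added in this commit. A minimal sketch of the request from Python follows; the base URL, file id, and bounding-box id are placeholders and not part of the commit:

```python
import requests

# Hypothetical base URL of a running pycs instance; file_id is an existing file
# and bbox_id the Result id of a bounding box that was drawn without a label.
BASE_URL = "http://localhost:5000"
file_id, bbox_id = 1, 42

# The endpoint expects a truthy "predict" flag in the JSON body and then
# schedules a classification job for this single bounding box.
response = requests.post(
    f"{BASE_URL}/data/{file_id}/{bbox_id}/predict_bounding_box",
    json={"predict": True},
)
response.raise_for_status()  # an empty 200 response means the job was queued
```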

+ 22 - 0
models/moth_scanner/scanner/__init__.py

@@ -1,14 +1,18 @@
+from typing import List
+
 import cv2
 import numpy as np
 
 from json import dump, load
 
 from pycs.interfaces.MediaFile import MediaFile
+from pycs.interfaces.MediaBoundingBox import MediaBoundingBox
 from pycs.interfaces.MediaStorage import MediaStorage
 from pycs.interfaces.Pipeline import Pipeline as Interface
 
 from .detector import Detector
 from .classifier import Classifier
+from .detector import BBox
 
 class Scanner(Interface):
     def __init__(self, root_folder: str, configuration: dict):
@@ -37,5 +41,23 @@ class Scanner(Interface):
             label = labels.get(cls_ref, cls_ref)
             file.add_bounding_box(x0, y0, bbox.w, bbox.h, label=label)
 
+    def pure_inference(self, storage: MediaStorage, file: MediaFile, bounding_boxes: List[MediaBoundingBox]):
+
+        im = self.read_image(file.path)
+        im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
+        bw_im = cv2.cvtColor(im, cv2.COLOR_RGB2GRAY)
+
+        labels = {ml.reference: ml for ml in storage.labels()}
+
+        bbox_labels = []
+        for bbox in bounding_boxes:
+            bbox = BBox(bbox.x, bbox.y, bbox.x + bbox.w, bbox.y + bbox.h)
+            x0, y0, x1, y1 = bbox
+            cls_ref = self.classifier(bbox.crop(im, enlarge=True))
+            label = labels.get(cls_ref, cls_ref)
+            bbox_labels.append(label)
+
+        return bbox_labels
+
     def read_image(self, path: str, mode: int = cv2.IMREAD_COLOR) -> np.ndarray:
         return cv2.imread(path, mode)

+ 6 - 0
pycs/frontend/WebServer.py

@@ -31,6 +31,7 @@ from pycs.frontend.endpoints.labels.ListLabelTree import ListLabelTree
 from pycs.frontend.endpoints.labels.ListLabels import ListLabels
 from pycs.frontend.endpoints.labels.RemoveLabel import RemoveLabel
 from pycs.frontend.endpoints.pipelines.FitModel import FitModel
+from pycs.frontend.endpoints.pipelines.PredictBoundingBox import PredictBoundingBox
 from pycs.frontend.endpoints.pipelines.PredictFile import PredictFile
 from pycs.frontend.endpoints.pipelines.PredictModel import PredictModel
 from pycs.frontend.endpoints.projects.CreateProject import CreateProject
@@ -339,6 +340,11 @@ class WebServer:
             view_func=PredictFile.as_view('predict_file', self.notifications,
                                           self.jobs, self.pipelines)
         )
+        self.app.add_url_rule(
+            '/data/<int:file_id>/<int:bbox_id>/predict_bounding_box',
+            view_func=PredictBoundingBox.as_view('predict_bounding_box', self.notifications,
+                                                 self.jobs, self.pipelines)
+        )
 
     def run(self):
         """ start web server """

+ 60 - 0
pycs/frontend/endpoints/pipelines/PredictBoundingBox.py

@@ -0,0 +1,60 @@
+from flask import abort
+from flask import make_response
+from flask import request
+from flask.views import View
+
+from pycs.database.Result import Result
+from pycs.database.File import File
+from pycs.frontend.endpoints.pipelines.PredictModel import PredictModel as Predict
+from pycs.frontend.notifications.NotificationList import NotificationList
+from pycs.frontend.notifications.NotificationManager import NotificationManager
+from pycs.jobs.JobGroupBusyException import JobGroupBusyException
+from pycs.jobs.JobRunner import JobRunner
+from pycs.util.PipelineCache import PipelineCache
+
+
+class PredictBoundingBox(View):
+    """
+    load a model and create a prediction for a given bounding box of a file
+    """
+    # pylint: disable=arguments-differ
+    methods = ['POST']
+
+    def __init__(self, nm: NotificationManager, jobs: JobRunner, pipelines: PipelineCache):
+        # pylint: disable=invalid-name
+        self.nm = nm
+        self.jobs = jobs
+        self.pipelines = pipelines
+
+    def dispatch_request(self, file_id, bbox_id):
+        # find file and result (=bounding box)
+        # We need the result to get (x,y,w,h)
+        file = File.get_or_404(file_id)
+        result = Result.get_or_404(bbox_id)
+
+        # extract request data
+        data = request.get_json(force=True)
+
+        if not data.get('predict', False):
+            abort(400, "predict flag is missing")
+
+        # get project and model
+        project = file.project
+
+        # create job
+        try:
+            notifications = NotificationList(self.nm)
+
+            self.jobs.run(project,
+                          'Model Interaction',
+                          f'{project.name} (create predictions)',
+                          f'{project.id}/model-interaction',
+                          Predict.load_and_pure_inference,
+                          self.pipelines, notifications, self.nm,
+                          project.id, [file.id], {file.id: [result]},
+                          progress=Predict.progress)
+
+        except JobGroupBusyException:
+            abort(400, "File prediction is already running")
+
+        return make_response()
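A minimal sketch of exercising this view through Flask's test client, assuming a wired-up WebServer instance and existing database entries (file 1 and bounding-box result 7 are placeholders):

```python
# web_server is assumed to be a WebServer instance from pycs/frontend/WebServer.py.
client = web_server.app.test_client()

# Without a truthy "predict" flag the request is rejected with 400.
assert client.post('/data/1/7/predict_bounding_box', json={}).status_code == 400

# With the flag set, the prediction job is queued and an empty 200 response is
# returned (400 if a job for this group is already running).
response = client.post('/data/1/7/predict_bounding_box', json={'predict': True})
assert response.status_code == 200
```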

+ 59 - 0
pycs/frontend/endpoints/pipelines/PredictModel.py

@@ -8,9 +8,11 @@ from flask.views import View
 
 from pycs import db
 from pycs.database.Project import Project
+from pycs.database.Result import Result
 from pycs.frontend.notifications.NotificationList import NotificationList
 from pycs.frontend.notifications.NotificationManager import NotificationManager
 from pycs.interfaces.MediaFile import MediaFile
+from pycs.interfaces.MediaBoundingBox import MediaBoundingBox
 from pycs.interfaces.MediaStorage import MediaStorage
 from pycs.jobs.JobGroupBusyException import JobGroupBusyException
 from pycs.jobs.JobRunner import JobRunner
@@ -122,6 +124,63 @@ class PredictModel(View):
             if pipeline is not None:
                 pipelines.free_instance(model_root)
 
+    @staticmethod
+    def load_and_pure_inference(pipelines: PipelineCache,
+                         notifications: NotificationList,
+                         nm: NotificationManager,
+                         project_id: int, file_filter: List[int], result_filter: dict[int, List[Result]]):
+        """
+        load the pipeline and call the execute function
+
+        :param pipelines: pipeline cache
+        :param notifications: notification object
+        :param nm: notification manager
+        :param project_id: project id
+        :param file_filter: list of file ids
+        :param result_filter: dict of file id and list of results (bounding boxes) to classify
+        :return:
+        """
+        pipeline = None
+
+        # create new database instance
+        project = Project.query.get(project_id)
+        model_root = project.model.root_folder
+        storage = MediaStorage(project_id, notifications)
+
+        # number of files to process, used for progress reporting
+        length = len(file_filter)
+
+        # load pipeline
+        try:
+            pipeline = pipelines.load_from_root_folder(project_id, model_root)
+
+            # iterate over media files
+            index = 0
+            for file_id in file_filter:
+                file = project.file(file_id)
+                file = MediaFile(file, notifications)
+                bounding_boxes = [MediaBoundingBox(result) for result in result_filter[file_id]]
+
+                # Perform inference.
+                bbox_labels = pipeline.pure_inference(storage, file, bounding_boxes)
+
+                # Add the labels determined in the inference process.
+                for i, result in enumerate(result_filter[file_id]):
+                    result.label_id = bbox_labels[i].identifier
+                    result.set_origin('user', commit=True)
+                    notifications.add(nm.edit_result, result)
+
+                # yield progress
+                yield index / length, notifications
+
+                index += 1
+
+        finally:
+            if pipeline is not None:
+                pipelines.free_instance(model_root)
 
     @staticmethod
     def progress(progress: float, notifications: NotificationList):

+ 1 - 1
pycs/interfaces/MediaBoundingBox.py

@@ -13,7 +13,7 @@ class MediaBoundingBox:
         self.y = result.data['y']
         self.w = result.data['w']
         self.h = result.data['h']
-        self.label = result.label
+        self.label = result.label if hasattr(result, 'label') else None
         self.frame = result.data['frame'] if 'frame' in result.data else None
 
     def serialize(self) -> dict:

+ 14 - 0
pycs/interfaces/Pipeline.py

@@ -1,6 +1,7 @@
 from typing import List
 
 from pycs.interfaces.MediaFile import MediaFile
+from pycs.interfaces.MediaBoundingBox import MediaBoundingBox
 from pycs.interfaces.MediaStorage import MediaStorage
 
 
@@ -69,6 +70,19 @@ class Pipeline:
         """
         raise NotImplementedError
 
+    def pure_inference(self, storage: MediaStorage, file: MediaFile, bounding_boxes: List[MediaBoundingBox]):
+        """
+        receive a file and a list of bounding boxes and create a
+        classification only for the given bounding boxes.
+
+        :param storage: database abstraction object
+        :param file: which should be analyzed
+        :param bounding_boxes: only perform inference for the given bounding boxes
+
+        :return: a list of labels, one for each given bounding box
+        """
+        raise NotImplementedError
+
     def fit(self, storage: MediaStorage):
         """
         receive a list of annotated media files and adapt the underlying model
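To illustrate the contract of the new interface method, here is a minimal sketch of a pipeline implementing only pure_inference. The constant-label strategy is made up for illustration (the real implementation is the Scanner above); load_and_pure_inference expects one label object per bounding box, in input order:

```python
from typing import List

from pycs.interfaces.MediaBoundingBox import MediaBoundingBox
from pycs.interfaces.MediaFile import MediaFile
from pycs.interfaces.MediaStorage import MediaStorage
from pycs.interfaces.Pipeline import Pipeline


class FirstLabelPipeline(Pipeline):
    """Illustrative pipeline that assigns the first known label to every bounding box."""

    def pure_inference(self, storage: MediaStorage, file: MediaFile,
                       bounding_boxes: List[MediaBoundingBox]):
        # storage.labels() provides the project's label objects; return one
        # label per bounding box, in the same order as the input list.
        labels = list(storage.labels())
        first = labels[0] if labels else None
        return [first for _ in bounding_boxes]
```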

File diff not shown because the file is too large
+ 14713 - 1
webui/package-lock.json


+ 63 - 2
webui/src/components/media/cropped-image.vue

@@ -6,9 +6,20 @@
       <img alt="close button" src="@/assets/icons/cross.svg">
     </div>
 
-    <div v-if="src" class="image-container">
-      <h3>{{ label }}</h3>
+    <div class="label-container">
+      <h3>{{ label }}</h3>
+
+      <div v-if="label === 'Unknown'"
+            ref="create_predictions"
+            class="create-predictions-icon"
+            title="create prediction for this image"
+            :class="{active: isPredictionRunning}"
+            @click="predict_cropped_image">
+        <img alt="create prediction" src="@/assets/icons/rocket.svg">
+      </div>
+    </div>
 
+    <div v-if="src" class="image-container">
       <img alt="crop" :src="src"/>
     </div>
     <div v-else>
@@ -47,6 +58,22 @@ export default {
 
       return 'Not found';
     }
+  },
+  methods: {
+    predict_cropped_image: function () {
+      // This shouldn't happen, since the icon is only shown if a bounding box
+      // was selected.
+      if (!this.box)
+        return;
+
+      if (!this.isPredictionRunning) {
+        // TODO then / error
+        // this should become this.box.identifier...
+        this.$root.socket.post(`/data/${this.file.identifier}/${this.box.identifier}/predict_bounding_box`, {
+          predict: true
+        });
+      }
+    }
   }
 }
 </script>
@@ -78,4 +105,38 @@ export default {
   border: 2px solid;
   max-width: 100%;
 }
+
+.label-container {
+  display: flex;
+  flex-direction: row;
+  align-items: baseline;
+  justify-content: center;
+}
+
+.create-predictions-icon {
+  display: flex;
+  justify-content: center;
+  align-items: center;
+
+  cursor: pointer;
+
+  margin: 0.4rem;
+  border: 1px solid whitesmoke;
+  border-radius: 0.5rem;
+
+  width: 1.6rem;
+  height: 1.6rem;
+}
+
+.create-predictions-icon.active {
+  background-color: rgba(0, 0, 0, 0.2);
+  box-shadow: inset 0 0 5px 0 rgba(0, 0, 0, 0.2);
+}
+
+.create-predictions-icon > img {
+  max-width: 1.1rem;
+  max-height: 1.1rem;
+  filter: invert(1);
+}
+
 </style>

Some files were not shown because too many files changed in this diff