6
0
Преглед на файлове

Inference for Custom Bounding Boxes

A user can now perform classification only for  a custom bounding box that does not have a label yet. That is, a user only has to draw the bounding box but not necessarily select the label by themself.
j-bl преди 3 години
родител
ревизия
1c9c605000

+ 22 - 0
models/moth_scanner/scanner/__init__.py

@@ -1,14 +1,18 @@
+from typing import List
+
 import cv2
 import cv2
 import numpy as np
 import numpy as np
 
 
 from json import dump, load
 from json import dump, load
 
 
 from pycs.interfaces.MediaFile import MediaFile
 from pycs.interfaces.MediaFile import MediaFile
+from pycs.interfaces.MediaBoundingBox import MediaBoundingBox
 from pycs.interfaces.MediaStorage import MediaStorage
 from pycs.interfaces.MediaStorage import MediaStorage
 from pycs.interfaces.Pipeline import Pipeline as Interface
 from pycs.interfaces.Pipeline import Pipeline as Interface
 
 
 from .detector import Detector
 from .detector import Detector
 from .classifier import Classifier
 from .classifier import Classifier
+from .detector import BBox
 
 
 class Scanner(Interface):
 class Scanner(Interface):
     def __init__(self, root_folder: str, configuration: dict):
     def __init__(self, root_folder: str, configuration: dict):
@@ -37,5 +41,23 @@ class Scanner(Interface):
             label = labels.get(cls_ref, cls_ref)
             label = labels.get(cls_ref, cls_ref)
             file.add_bounding_box(x0, y0, bbox.w, bbox.h, label=label)
             file.add_bounding_box(x0, y0, bbox.w, bbox.h, label=label)
 
 
+    def pure_inference(self, storage: MediaStorage, file: MediaFile, bounding_boxes: List[MediaBoundingBox]):
+
+        im = self.read_image(file.path)
+        im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
+        bw_im = cv2.cvtColor(im, cv2.COLOR_RGB2GRAY)
+
+        labels = {ml.reference: ml for ml in storage.labels()}
+
+        bbox_labels = []
+        for bbox in bounding_boxes:
+            bbox = BBox(bbox.x, bbox.y, bbox.x + bbox.w , bbox.y + bbox.h)
+            x0, y0, x1, y1 = bbox
+            cls_ref = self.classifier(bbox.crop(im, enlarge=True))
+            label = labels.get(cls_ref, cls_ref)
+            bbox_labels.append(label)
+
+        return bbox_labels
+
     def read_image(self, path: str, mode: int = cv2.IMREAD_COLOR) -> np.ndarray:
     def read_image(self, path: str, mode: int = cv2.IMREAD_COLOR) -> np.ndarray:
         return cv2.imread(path, mode)
         return cv2.imread(path, mode)

+ 6 - 0
pycs/frontend/WebServer.py

@@ -31,6 +31,7 @@ from pycs.frontend.endpoints.labels.ListLabelTree import ListLabelTree
 from pycs.frontend.endpoints.labels.ListLabels import ListLabels
 from pycs.frontend.endpoints.labels.ListLabels import ListLabels
 from pycs.frontend.endpoints.labels.RemoveLabel import RemoveLabel
 from pycs.frontend.endpoints.labels.RemoveLabel import RemoveLabel
 from pycs.frontend.endpoints.pipelines.FitModel import FitModel
 from pycs.frontend.endpoints.pipelines.FitModel import FitModel
+from pycs.frontend.endpoints.pipelines.PredictBoundingBox import PredictBoundingBox
 from pycs.frontend.endpoints.pipelines.PredictFile import PredictFile
 from pycs.frontend.endpoints.pipelines.PredictFile import PredictFile
 from pycs.frontend.endpoints.pipelines.PredictModel import PredictModel
 from pycs.frontend.endpoints.pipelines.PredictModel import PredictModel
 from pycs.frontend.endpoints.projects.CreateProject import CreateProject
 from pycs.frontend.endpoints.projects.CreateProject import CreateProject
@@ -339,6 +340,11 @@ class WebServer:
             view_func=PredictFile.as_view('predict_file', self.notifications,
             view_func=PredictFile.as_view('predict_file', self.notifications,
                                           self.jobs, self.pipelines)
                                           self.jobs, self.pipelines)
         )
         )
+        self.app.add_url_rule(
+            '/data/<int:file_id>/<int:bbox_id>/predict_bounding_box',
+            view_func=PredictBoundingBox.as_view('predict_bounding_box', self.notifications,
+                                          self.jobs, self.pipelines)
+        )
 
 
     def run(self):
     def run(self):
         """ start web server """
         """ start web server """

+ 60 - 0
pycs/frontend/endpoints/pipelines/PredictBoundingBox.py

@@ -0,0 +1,60 @@
+from flask import abort
+from flask import make_response
+from flask import request
+from flask.views import View
+
+from pycs.database.Result import Result
+from pycs.database.File import File
+from pycs.frontend.endpoints.pipelines.PredictModel import PredictModel as Predict
+from pycs.frontend.notifications.NotificationList import NotificationList
+from pycs.frontend.notifications.NotificationManager import NotificationManager
+from pycs.jobs.JobGroupBusyException import JobGroupBusyException
+from pycs.jobs.JobRunner import JobRunner
+from pycs.util.PipelineCache import PipelineCache
+
+
+class PredictBoundingBox(View):
+    """
+    load a model and create predictions or a given file
+    """
+    # pylint: disable=arguments-differ
+    methods = ['POST']
+
+    def __init__(self, nm: NotificationManager, jobs: JobRunner, pipelines: PipelineCache):
+        # pylint: disable=invalid-name
+        self.nm = nm
+        self.jobs = jobs
+        self.pipelines = pipelines
+
+    def dispatch_request(self, file_id, bbox_id):
+        # find file and result (=bounding box)
+        # We need the result to get (x,y,w,h)
+        file = File.get_or_404(file_id)
+        result = Result.get_or_404(bbox_id)
+
+        # extract request data
+        data = request.get_json(force=True)
+
+        if not data.get('predict', False):
+            abort(400, "predict flag is missing")
+
+        # get project and model
+        project = file.project
+
+        # create job
+        try:
+            notifications = NotificationList(self.nm)
+
+            self.jobs.run(project,
+                          'Model Interaction',
+                          f'{project.name} (create predictions)',
+                          f'{project.id}/model-interaction',
+                          Predict.load_and_pure_inference,
+                          self.pipelines, notifications, self.nm,
+                          project.id, [file.id], {file.id: [result]},
+                          progress=Predict.progress)
+
+        except JobGroupBusyException:
+            abort(400, "File prediction is already running")
+
+        return make_response()

+ 59 - 0
pycs/frontend/endpoints/pipelines/PredictModel.py

@@ -8,9 +8,11 @@ from flask.views import View
 
 
 from pycs import db
 from pycs import db
 from pycs.database.Project import Project
 from pycs.database.Project import Project
+from pycs.database.Result import Result
 from pycs.frontend.notifications.NotificationList import NotificationList
 from pycs.frontend.notifications.NotificationList import NotificationList
 from pycs.frontend.notifications.NotificationManager import NotificationManager
 from pycs.frontend.notifications.NotificationManager import NotificationManager
 from pycs.interfaces.MediaFile import MediaFile
 from pycs.interfaces.MediaFile import MediaFile
+from pycs.interfaces.MediaBoundingBox import MediaBoundingBox
 from pycs.interfaces.MediaStorage import MediaStorage
 from pycs.interfaces.MediaStorage import MediaStorage
 from pycs.jobs.JobGroupBusyException import JobGroupBusyException
 from pycs.jobs.JobGroupBusyException import JobGroupBusyException
 from pycs.jobs.JobRunner import JobRunner
 from pycs.jobs.JobRunner import JobRunner
@@ -122,6 +124,63 @@ class PredictModel(View):
             if pipeline is not None:
             if pipeline is not None:
                 pipelines.free_instance(model_root)
                 pipelines.free_instance(model_root)
 
 
+    @staticmethod
+    def load_and_pure_inference(pipelines: PipelineCache,
+                         notifications: NotificationList,
+                         nm: NotificationManager,
+                         project_id: int, file_filter: List[int], result_filter: dict[int, List[Result]]):
+        """
+        load the pipeline and call the execute function
+
+        :param database: database object
+        :param pipelines: pipeline cache
+        :param notifications: notification object
+        :param nm: notification manager
+        :param project_id: project id
+        :param file_filter: list of file ids
+        :param result_filter: dict of file id and list of results (bounding boxes) to classify
+        :return:
+        """
+        pipeline = None
+
+        # create new database instance
+        project = Project.query.get(project_id)
+        model_root = project.model.root_folder
+        storage = MediaStorage(project_id, notifications)
+
+        # create a list of MediaFile
+        # Also convert dict to the same key type.
+        length = len(file_filter)
+
+
+        # load pipeline
+        try:
+            pipeline = pipelines.load_from_root_folder(project_id, model_root)
+
+            # iterate over media files
+            index = 0
+            for file_id in file_filter:
+                file = project.file(file_id)
+                file = MediaFile(file, notifications)
+                bounding_boxes = [MediaBoundingBox(result) for result in result_filter[file_id]]
+
+                # Perform inference.
+                bbox_labels = pipeline.pure_inference(storage, file, bounding_boxes)
+
+                # Add the labels determined in the inference process.
+                for i, result in enumerate(result_filter[file_id]):
+                    result.label_id = bbox_labels[i].identifier
+                    result.set_origin('user', commit=True)
+                    notifications.add(nm.edit_result, result)
+
+                # yield progress
+                yield index / length, notifications
+
+                index += 1
+
+        finally:
+            if pipeline is not None:
+                pipelines.free_instance(model_root)
 
 
     @staticmethod
     @staticmethod
     def progress(progress: float, notifications: NotificationList):
     def progress(progress: float, notifications: NotificationList):

+ 1 - 1
pycs/interfaces/MediaBoundingBox.py

@@ -13,7 +13,7 @@ class MediaBoundingBox:
         self.y = result.data['y']
         self.y = result.data['y']
         self.w = result.data['w']
         self.w = result.data['w']
         self.h = result.data['h']
         self.h = result.data['h']
-        self.label = result.label
+        self.label = result.label if hasattr(self, 'label') else None
         self.frame = result.data['frame'] if 'frame' in result.data else None
         self.frame = result.data['frame'] if 'frame' in result.data else None
 
 
     def serialize(self) -> dict:
     def serialize(self) -> dict:

+ 14 - 0
pycs/interfaces/Pipeline.py

@@ -1,6 +1,7 @@
 from typing import List
 from typing import List
 
 
 from pycs.interfaces.MediaFile import MediaFile
 from pycs.interfaces.MediaFile import MediaFile
+from pycs.interfaces.MediaBoundingBox import MediaBoundingBox
 from pycs.interfaces.MediaStorage import MediaStorage
 from pycs.interfaces.MediaStorage import MediaStorage
 
 
 
 
@@ -69,6 +70,19 @@ class Pipeline:
         """
         """
         raise NotImplementedError
         raise NotImplementedError
 
 
+    def pure_inference(self, storage: MediaStorage, file: MediaFile, bounding_boxes: List[MediaBoundingBox]):
+        """
+        receive a file and a list of bounding boxes and only create a
+        classification for the given bounding boxes.
+
+        :param storage: database abstraction object
+        :param file: which should be analyzed
+        :param bounding_boxes: only perform inference for the given bounding boxes
+
+        :return: labels for the given bounding boxes
+        """
+        raise NotImplementedError
+
     def fit(self, storage: MediaStorage):
     def fit(self, storage: MediaStorage):
         """
         """
         receive a list of annotated media files and adapt the underlying model
         receive a list of annotated media files and adapt the underlying model

Файловите разлики са ограничени, защото са твърде много
+ 14713 - 1
webui/package-lock.json


+ 63 - 2
webui/src/components/media/cropped-image.vue

@@ -6,9 +6,20 @@
       <img alt="close button" src="@/assets/icons/cross.svg">
       <img alt="close button" src="@/assets/icons/cross.svg">
     </div>
     </div>
 
 
-    <div v-if="src" class="image-container">
-      <h3>{{ label }}</h3>
+    <div class="label-container">
+      <h3> {{ label }} </h3>
+
+      <div  v-if="label === 'Unknown'"
+            ref="create_predictions"
+            class="create-predictions-icon"
+            title="create prediction for this image"
+            :class="{active: isPredictionRunning}"
+            @click="predict_cropped_image">
+        <img alt="create prediction" src="@/assets/icons/rocket.svg">
+      </div>
+    </div>
 
 
+    <div v-if="src" class="image-container">
       <img alt="crop" :src="src"/>
       <img alt="crop" :src="src"/>
     </div>
     </div>
     <div v-else>
     <div v-else>
@@ -47,6 +58,22 @@ export default {
 
 
       return 'Not found';
       return 'Not found';
     }
     }
+  },
+  methods: {
+    predict_cropped_image: function () {
+      // This shouldn't happen, since the icon is only shown if a bounding box
+      // was selected.
+      if (!this.box)
+        return;
+
+      if (!this.isPredictionRunning) {
+        // TODO then / error
+        // this should become this.box.identifier...
+        this.$root.socket.post(`/data/${this.file.identifier}/${this.box.identifier}/predict_bounding_box`, {
+          predict: true
+        });
+      }
+    }
   }
   }
 }
 }
 </script>
 </script>
@@ -78,4 +105,38 @@ export default {
   border: 2px solid;
   border: 2px solid;
   max-width: 100%;
   max-width: 100%;
 }
 }
+
+.label-container {
+  display: flex;
+  flex-direction: row;
+  align-items: baseline;
+  justify-content: center;
+}
+
+.create-predictions-icon {
+  display: flex;
+  justify-content: center;
+  align-items: center;
+
+  cursor: pointer;
+
+  margin: 0.4rem;
+  border: 1px solid whitesmoke;
+  border-radius: 0.5rem;
+
+  width: 1.6rem;
+  height: 1.6rem;
+}
+
+.create-predictions-icon.active {
+  background-color: rgba(0, 0, 0, 0.2);
+  box-shadow: inset 0 0 5px 0 rgba(0, 0, 0, 0.2);
+}
+
+.create-predictions-icon > img {
+  max-width: 1.1rem;
+  max-height: 1.1rem;
+  filter: invert(1);
+}
+
 </style>
 </style>

Някои файлове не бяха показани, защото твърде много файлове са промени