Browse code

Added support for single bounding box inference

Dimitri Korsch 3 years ago
parent
commit
5b4bdb32a3

+ 21 - 1
models/moth_scanner/scanner/__init__.py

@@ -1,14 +1,18 @@
+from typing import List
+
 import cv2
 import numpy as np
 
 from json import dump, load
 
 from pycs.interfaces.MediaFile import MediaFile
+from pycs.interfaces.MediaBoundingBox import MediaBoundingBox
 from pycs.interfaces.MediaStorage import MediaStorage
 from pycs.interfaces.Pipeline import Pipeline as Interface
 
 from .detector import Detector
 from .classifier import Classifier
+from .detector import BBox
 
 class Scanner(Interface):
     def __init__(self, root_folder: str, configuration: dict):
@@ -20,7 +24,6 @@ class Scanner(Interface):
         pass
 
     def execute(self, storage: MediaStorage, file: MediaFile):
-
         im = self.read_image(file.path)
         im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
         bw_im = cv2.cvtColor(im, cv2.COLOR_RGB2GRAY)
@@ -37,5 +40,22 @@ class Scanner(Interface):
             label = labels.get(cls_ref, cls_ref)
             file.add_bounding_box(x0, y0, bbox.w, bbox.h, label=label)
 
+    def pure_inference(self, storage: MediaStorage, file: MediaFile, bounding_boxes: List[MediaBoundingBox]):
+        im = self.read_image(file.path)
+        im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
+        bw_im = cv2.cvtColor(im, cv2.COLOR_RGB2GRAY)
+
+        labels = {ml.reference: ml for ml in storage.labels()}
+
+        bbox_labels = []
+        for bbox in bounding_boxes:
+            bbox = BBox(bbox.x, bbox.y, bbox.x + bbox.w , bbox.y + bbox.h)
+            x0, y0, x1, y1 = bbox
+            cls_ref = self.classifier(bbox.crop(im, enlarge=True))
+            label = labels.get(cls_ref, cls_ref)
+            bbox_labels.append(label)
+
+        return bbox_labels
+
     def read_image(self, path: str, mode: int = cv2.IMREAD_COLOR) -> np.ndarray:
         return cv2.imread(path, mode)

+ 6 - 0
pycs/frontend/WebServer.py

@@ -31,6 +31,7 @@ from pycs.frontend.endpoints.labels.ListLabelTree import ListLabelTree
 from pycs.frontend.endpoints.labels.ListLabels import ListLabels
 from pycs.frontend.endpoints.labels.RemoveLabel import RemoveLabel
 from pycs.frontend.endpoints.pipelines.FitModel import FitModel
+from pycs.frontend.endpoints.pipelines.PredictBoundingBox import PredictBoundingBox
 from pycs.frontend.endpoints.pipelines.PredictFile import PredictFile
 from pycs.frontend.endpoints.pipelines.PredictModel import PredictModel
 from pycs.frontend.endpoints.projects.CreateProject import CreateProject
@@ -339,6 +340,11 @@ class WebServer:
             view_func=PredictFile.as_view('predict_file', self.notifications,
                                           self.jobs, self.pipelines)
         )
+        self.app.add_url_rule(
+            '/data/<int:file_id>/<int:bbox_id>/predict_bounding_box',
+            view_func=PredictBoundingBox.as_view('predict_bounding_box', self.notifications,
+                                          self.jobs, self.pipelines)
+        )
 
     def run(self):
         """ start web server """

+ 60 - 0
pycs/frontend/endpoints/pipelines/PredictBoundingBox.py

@@ -0,0 +1,60 @@
+from flask import abort
+from flask import make_response
+from flask import request
+from flask.views import View
+
+from pycs.database.Result import Result
+from pycs.database.File import File
+from pycs.frontend.endpoints.pipelines.PredictModel import PredictModel as Predict
+from pycs.frontend.notifications.NotificationList import NotificationList
+from pycs.frontend.notifications.NotificationManager import NotificationManager
+from pycs.jobs.JobGroupBusyException import JobGroupBusyException
+from pycs.jobs.JobRunner import JobRunner
+from pycs.util.PipelineCache import PipelineCache
+
+
class PredictBoundingBox(View):
    """
    load a model and create a prediction for a single bounding box of a file
    """
    # pylint: disable=arguments-differ
    methods = ['POST']

    def __init__(self, nm: NotificationManager, jobs: JobRunner, pipelines: PipelineCache):
        # pylint: disable=invalid-name
        self.nm = nm
        self.jobs = jobs
        self.pipelines = pipelines

    def dispatch_request(self, file_id, bbox_id):
        """Start a classification job for one bounding box (result) of a file.

        :param file_id: id of the file the bounding box belongs to
        :param bbox_id: id of the result row holding the box coordinates
        """
        # find file and result (=bounding box);
        # the result is needed for its (x, y, w, h) coordinates
        file = File.get_or_404(file_id)
        result = Result.get_or_404(bbox_id)

        # extract request data
        data = request.get_json(force=True)

        if not data.get('predict', False):
            abort(400, "predict flag is missing")

        # get project and model
        project = file.project

        # create job; the group key serializes model interactions per project
        try:
            notifications = NotificationList(self.nm)

            self.jobs.run(project,
                          'Model Interaction',
                          f'{project.name} (create predictions)',
                          f'{project.id}/model-interaction',
                          Predict.load_and_pure_inference,
                          self.pipelines, notifications, self.nm,
                          project.id, [file.id], {file.id: [result]},
                          progress=Predict.progress)

        except JobGroupBusyException:
            abort(400, "File prediction is already running")

        return make_response()

+ 60 - 0
pycs/frontend/endpoints/pipelines/PredictModel.py

@@ -8,9 +8,11 @@ from flask.views import View
 
 from pycs import db
 from pycs.database.Project import Project
+from pycs.database.Result import Result
 from pycs.frontend.notifications.NotificationList import NotificationList
 from pycs.frontend.notifications.NotificationManager import NotificationManager
 from pycs.interfaces.MediaFile import MediaFile
+from pycs.interfaces.MediaBoundingBox import MediaBoundingBox
 from pycs.interfaces.MediaStorage import MediaStorage
 from pycs.jobs.JobGroupBusyException import JobGroupBusyException
 from pycs.jobs.JobRunner import JobRunner
@@ -122,6 +124,64 @@ class PredictModel(View):
             if pipeline is not None:
                 pipelines.free_instance(model_root)
 
+    @staticmethod
+    def load_and_pure_inference(pipelines: PipelineCache,
+                         notifications: NotificationList,
+                         notification_manager: NotificationManager,
+                         project_id: int, file_filter: List[int],
+                         result_filter: dict[int, List[Result]]):
+        """
+        load the pipeline and call the execute function
+
+        :param database: database object
+        :param pipelines: pipeline cache
+        :param notifications: notification object
+        :param notification_manager: notification manager
+        :param project_id: project id
+        :param file_filter: list of file ids
+        :param result_filter: dict of file id and list of results to classify
+        :return:
+        """
+        pipeline = None
+
+        # create new database instance
+        project = Project.query.get(project_id)
+        model_root = project.model.root_folder
+        storage = MediaStorage(project_id, notifications)
+
+        # create a list of MediaFile
+        # Also convert dict to the same key type.
+        length = len(file_filter)
+
+
+        # load pipeline
+        try:
+            pipeline = pipelines.load_from_root_folder(project_id, model_root)
+
+            # iterate over media files
+            index = 0
+            for file_id in file_filter:
+                file = project.file(file_id)
+                file = MediaFile(file, notifications)
+                bounding_boxes = [MediaBoundingBox(result) for result in result_filter[file_id]]
+
+                # Perform inference.
+                bbox_labels = pipeline.pure_inference(storage, file, bounding_boxes)
+
+                # Add the labels determined in the inference process.
+                for i, result in enumerate(result_filter[file_id]):
+                    result.label_id = bbox_labels[i].identifier
+                    result.set_origin('user', commit=True)
+                    notifications.add(notification_manager.edit_result, result)
+
+                # yield progress
+                yield index / length, notifications
+
+                index += 1
+
+        finally:
+            if pipeline is not None:
+                pipelines.free_instance(model_root)
 
     @staticmethod
     def progress(progress: float, notifications: NotificationList):

+ 1 - 1
pycs/interfaces/MediaBoundingBox.py

@@ -13,7 +13,7 @@ class MediaBoundingBox:
         self.y = result.data['y']
         self.w = result.data['w']
         self.h = result.data['h']
-        self.label = result.label
+        self.label = result.label if hasattr(self, 'label') else None
         self.frame = result.data['frame'] if 'frame' in result.data else None
 
     def serialize(self) -> dict:

+ 14 - 0
pycs/interfaces/Pipeline.py

@@ -1,6 +1,7 @@
 from typing import List
 
 from pycs.interfaces.MediaFile import MediaFile
+from pycs.interfaces.MediaBoundingBox import MediaBoundingBox
 from pycs.interfaces.MediaStorage import MediaStorage
 
 
@@ -69,6 +70,19 @@ class Pipeline:
         """
         raise NotImplementedError
 
+    def pure_inference(self, storage: MediaStorage, file: MediaFile, bounding_boxes: List[MediaBoundingBox]):
+        """
+        receive a file and a list of bounding boxes and only create a
+        classification for the given bounding boxes.
+
+        :param storage: database abstraction object
+        :param file: which should be analyzed
+        :param bounding_boxes: only perform inference for the given bounding boxes
+
+        :return: labels for the given bounding boxes
+        """
+        raise NotImplementedError
+
     def fit(self, storage: MediaStorage):
         """
         receive a list of annotated media files and adapt the underlying model

+ 119 - 2
webui/src/components/media/cropped-image.vue

@@ -6,9 +6,20 @@
       <img alt="close button" src="@/assets/icons/cross.svg">
     </div>
 
-    <div v-if="src" class="image-container">
-      <h3>{{ label }}</h3>
+    <div class="label-container">
+      <h3> {{ label }} </h3>
+
+      <div  v-if="this.box.origin === 'user'"
+            ref="create_predictions"
+            class="create-predictions-icon"
+            title="create prediction for this image"
+            :class="{active: isPredictionRunning}"
+            @click="predict_cropped_image">
+        <img alt="create prediction" src="@/assets/icons/rocket.svg">
+      </div>
+    </div>
 
+    <div v-if="src" class="image-container">
       <img alt="crop" :src="src"/>
     </div>
     <div v-else>
@@ -21,6 +32,29 @@
 export default {
   name: "cropped-image",
   props: ['labels', 'file', 'box'],
+  created: function () {
+    // get data
+    this.getJobs();
+
+    // subscribe to changes
+    this.$root.socket.on('connect', this.getJobs);
+    this.$root.socket.on('create-job', this.addJob);
+    this.$root.socket.on('remove-job', this.removeJob);
+    this.$root.socket.on('edit-job', this.editJob);
+  },
+  destroyed: function () {
+    this.$root.socket.off('connect', this.getJobs);
+    this.$root.socket.off('create-job', this.addJob);
+    this.$root.socket.off('remove-job', this.removeJob);
+    this.$root.socket.off('edit-job', this.editJob);
+  },
+  data: function () {
+    return {
+      jobs: [],
+      labelSelector: false,
+      model: null
+    }
+  },
   computed: {
     src: function () {
       if (!this.box)
@@ -46,6 +80,55 @@ export default {
           return label.name;
 
       return 'Not found';
+    },
+    isPredictionRunning: function () {
+      return this.jobs.filter(j => !j.finished && j.type === 'Model Interaction').length > 0;
+    }
+  },
+  methods: {
+    getJobs: function () {
+      this.$root.socket.get('/jobs')
+          .then(response => response.json())
+          .then(jobs => {
+            this.jobs = [];
+            jobs.forEach(this.addJob)
+          });
+    },
+    addJob: function (job) {
+      for (let j of this.jobs)
+        if (j.identifier === job.identifier)
+          return;
+
+      this.jobs.push(job);
+    },
+    removeJob: function (job) {
+      for (let i = 0; i < this.jobs.length; i++) {
+        if (this.jobs[i].identifier === job.identifier) {
+          this.jobs.splice(i, 1);
+          return;
+        }
+      }
+    },
+    editJob: function (job) {
+      for (let i = 0; i < this.jobs.length; i++) {
+        if (this.jobs[i].identifier === job.identifier) {
+          this.$set(this.jobs, i, job);
+          return;
+        }
+      }
+    },
+    predict_cropped_image: function () {
+      // This shouldn't happen, since the icon is only shown if a bounding box
+      // was selected.
+      if (!this.box)
+        return;
+
+      if (!this.isPredictionRunning) {
+        // TODO then / error
+        this.$root.socket.post(`/data/${this.file.identifier}/${this.box.identifier}/predict_bounding_box`, {
+          predict: true
+        });
+      }
     }
   }
 }
@@ -78,4 +161,38 @@ export default {
   border: 2px solid;
   max-width: 100%;
 }
+
+.label-container {
+  display: flex;
+  flex-direction: row;
+  align-items: baseline;
+  justify-content: center;
+}
+
+.create-predictions-icon {
+  display: flex;
+  justify-content: center;
+  align-items: center;
+
+  cursor: pointer;
+
+  margin: 0.4rem;
+  border: 1px solid whitesmoke;
+  border-radius: 0.5rem;
+
+  width: 1.6rem;
+  height: 1.6rem;
+}
+
+.create-predictions-icon.active {
+  background-color: rgba(0, 0, 0, 0.2);
+  box-shadow: inset 0 0 5px 0 rgba(0, 0, 0, 0.2);
+}
+
+.create-predictions-icon > img {
+  max-width: 1.1rem;
+  max-height: 1.1rem;
+  filter: invert(1);
+}
+
 </style>