|
@@ -8,9 +8,11 @@ from flask.views import View
|
|
|
|
|
|
from pycs import db
|
|
|
from pycs.database.Project import Project
|
|
|
+from pycs.database.Result import Result
|
|
|
from pycs.frontend.notifications.NotificationList import NotificationList
|
|
|
from pycs.frontend.notifications.NotificationManager import NotificationManager
|
|
|
from pycs.interfaces.MediaFile import MediaFile
|
|
|
+from pycs.interfaces.MediaBoundingBox import MediaBoundingBox
|
|
|
from pycs.interfaces.MediaStorage import MediaStorage
|
|
|
from pycs.jobs.JobGroupBusyException import JobGroupBusyException
|
|
|
from pycs.jobs.JobRunner import JobRunner
|
|
@@ -122,6 +124,63 @@ class PredictModel(View):
|
|
|
if pipeline is not None:
|
|
|
pipelines.free_instance(model_root)
|
|
|
|
|
|
+ @staticmethod
|
|
|
+ def load_and_pure_inference(pipelines: PipelineCache,
|
|
|
+ notifications: NotificationList,
|
|
|
+ nm: NotificationManager,
|
|
|
+ project_id: int, file_filter: List[int], result_filter: dict[int, List[Result]]):
|
|
|
+ """
|
|
|
+ load the pipeline and call the execute function
|
|
|
+
|
|
|
+ :param database: database object
|
|
|
+ :param pipelines: pipeline cache
|
|
|
+ :param notifications: notification object
|
|
|
+ :param nm: notification manager
|
|
|
+ :param project_id: project id
|
|
|
+ :param file_filter: list of file ids
|
|
|
+ :param result_filter: dict of file id and list of results (bounding boxes) to classify
|
|
|
+ :return:
|
|
|
+ """
|
|
|
+ pipeline = None
|
|
|
+
|
|
|
+ # create new database instance
|
|
|
+ project = Project.query.get(project_id)
|
|
|
+ model_root = project.model.root_folder
|
|
|
+ storage = MediaStorage(project_id, notifications)
|
|
|
+
|
|
|
+ # create a list of MediaFile
|
|
|
+ # Also convert dict to the same key type.
|
|
|
+ length = len(file_filter)
|
|
|
+
|
|
|
+
|
|
|
+ # load pipeline
|
|
|
+ try:
|
|
|
+ pipeline = pipelines.load_from_root_folder(project_id, model_root)
|
|
|
+
|
|
|
+ # iterate over media files
|
|
|
+ index = 0
|
|
|
+ for file_id in file_filter:
|
|
|
+ file = project.file(file_id)
|
|
|
+ file = MediaFile(file, notifications)
|
|
|
+ bounding_boxes = [MediaBoundingBox(result) for result in result_filter[file_id]]
|
|
|
+
|
|
|
+ # Perform inference.
|
|
|
+ bbox_labels = pipeline.pure_inference(storage, file, bounding_boxes)
|
|
|
+
|
|
|
+ # Add the labels determined in the inference process.
|
|
|
+ for i, result in enumerate(result_filter[file_id]):
|
|
|
+ result.label_id = bbox_labels[i].identifier
|
|
|
+ result.set_origin('user', commit=True)
|
|
|
+ notifications.add(nm.edit_result, result)
|
|
|
+
|
|
|
+ # yield progress
|
|
|
+ yield index / length, notifications
|
|
|
+
|
|
|
+ index += 1
|
|
|
+
|
|
|
+ finally:
|
|
|
+ if pipeline is not None:
|
|
|
+ pipelines.free_instance(model_root)
|
|
|
|
|
|
@staticmethod
|
|
|
def progress(progress: float, notifications: NotificationList):
|