|
@@ -33,6 +33,7 @@ class PredictModel(View):
|
|
|
self.pipelines = pipelines
|
|
|
|
|
|
def dispatch_request(self, user: str, project_id):
|
|
|
+ # pylint: disable=unused-argument
|
|
|
project = Project.get_or_404(project_id)
|
|
|
|
|
|
# extract request data
|
|
@@ -129,7 +130,7 @@ class PredictModel(View):
|
|
|
notifications: NotificationList,
|
|
|
notification_manager: NotificationManager,
|
|
|
project_id: int, file_filter: List[int],
|
|
|
- result_filter: dict[int, List[Result]]):
|
|
|
+ bbox_id_filter: dict[int, List[int]], user: str):
|
|
|
"""
|
|
|
load the pipeline and call the execute function
|
|
|
|
|
@@ -139,7 +140,8 @@ class PredictModel(View):
|
|
|
:param notification_manager: notification manager
|
|
|
:param project_id: project id
|
|
|
:param file_filter: list of file ids
|
|
|
- :param result_filter: dict of file id and list of results to classify
|
|
|
+ :param bbox_id_filter: dict of file id and list of bbox_ids to classify
|
|
|
+ :param user: username of the user asking to predict the bounding box
|
|
|
:return:
|
|
|
"""
|
|
|
pipeline = None
|
|
@@ -149,32 +151,31 @@ class PredictModel(View):
|
|
|
model_root = project.model.root_folder
|
|
|
storage = MediaStorage(project_id, notifications)
|
|
|
|
|
|
- # create a list of MediaFile
|
|
|
- # Also convert dict to the same key type.
|
|
|
- length = len(file_filter)
|
|
|
-
|
|
|
-
|
|
|
# load pipeline
|
|
|
try:
|
|
|
pipeline = pipelines.load_from_root_folder(project_id, model_root)
|
|
|
|
|
|
# iterate over media files
|
|
|
index = 0
|
|
|
+ length = len(file_filter)
|
|
|
for file_id in file_filter:
|
|
|
file = project.file(file_id)
|
|
|
file = MediaFile(file, notifications)
|
|
|
- bounding_boxes = [MediaBoundingBox(result) for result in result_filter[file_id]]
|
|
|
+ bounding_boxes = [MediaBoundingBox(Result.get_or_404(bbox_id))
|
|
|
+ for bbox_id in bbox_id_filter[file_id]]
|
|
|
|
|
|
# Perform inference.
|
|
|
bbox_labels = pipeline.pure_inference(storage, file, bounding_boxes)
|
|
|
|
|
|
# Add the labels determined in the inference process.
|
|
|
- for i, result in enumerate(result_filter[file_id]):
|
|
|
- result.label_id = bbox_labels[i].identifier
|
|
|
- result.set_origin('user', origin_user=None, commit=True)
|
|
|
+ for i, bbox_id in enumerate(bbox_id_filter[file_id]):
|
|
|
+ result = Result.get_or_404(bbox_id)
|
|
|
+ result.set_label(bbox_labels[i].identifier, commit=True)
|
|
|
+ result.set_origin('user', origin_user=user, commit=True)
|
|
|
notifications.add(notification_manager.edit_result, result)
|
|
|
|
|
|
- # yield progress
|
|
|
+ # commit changes and yield progress
|
|
|
+ db.session.commit()
|
|
|
yield index / length, notifications
|
|
|
|
|
|
index += 1
|