from typing import List
from typing import Union

from flask import abort
from flask import make_response
from flask import request
from flask.views import View

from pycs import db
from pycs.database.Project import Project
from pycs.database.Result import Result
from pycs.frontend.notifications.NotificationList import NotificationList
from pycs.frontend.notifications.NotificationManager import NotificationManager
from pycs.interfaces.MediaFile import MediaFile
from pycs.interfaces.MediaLabel import MediaLabel
from pycs.interfaces.MediaBoundingBox import MediaBoundingBox
from pycs.interfaces.MediaStorage import MediaStorage
from pycs.jobs.JobGroupBusyException import JobGroupBusyException
from pycs.jobs.JobRunner import JobRunner
from pycs.util.PipelineCache import PipelineCache


class PredictModel(View):
    """
    load a model and create predictions
    """
    # pylint: disable=arguments-differ
    methods = ['POST']

    def __init__(self, nm: NotificationManager, jobs: JobRunner, pipelines: PipelineCache):
        """
        :param nm: notification manager used to build per-job notification lists
        :param jobs: job runner that executes the prediction generators
        :param pipelines: cache of loaded model pipelines
        """
        # pylint: disable=invalid-name
        self.nm = nm
        self.jobs = jobs
        self.pipelines = pipelines

    def dispatch_request(self, project_id: int):
        """
        handle a POST request: validate the `predict` argument and schedule a
        prediction job for the given project

        :param project_id: project id
        :return: empty 200 response on success; aborts with 400 on bad input
            or when a prediction job for this project is already running
        """
        project = Project.get_or_404(project_id)

        # extract request data
        data = request.get_json(force=True)
        predict = data.get('predict')

        if predict is None:
            abort(400, "predict argument is missing")

        if predict not in ['all', 'new']:
            abort(400, "predict must be either 'all' or 'new'")

        # create job
        try:
            notifications = NotificationList(self.nm)

            self.jobs.run(project,
                          'Model Interaction',
                          f'{project.name} (create predictions)',
                          f'{project.id}/model-interaction',
                          PredictModel.load_and_predict,
                          self.pipelines, notifications, project.id, predict,
                          progress=self.progress)
        except JobGroupBusyException:
            abort(400, "Model prediction is already running")

        return make_response()

    @staticmethod
    def load_and_predict(pipelines: PipelineCache,
                         notifications: NotificationList,
                         project_id: int,
                         file_filter: Union[str, List[int]]):
        """
        load the pipeline and call the execute function

        :param pipelines: pipeline cache
        :param notifications: notification object
        :param project_id: project id
        :param file_filter: list of file ids or 'new' / 'all'
        :return: generator yielding (progress, notifications) tuples,
            with progress in (0, 1] and 1.0 after the last file
        """
        pipeline = None

        # create new database instance
        project = Project.query.get(project_id)
        model_root = project.model.root_folder
        storage = MediaStorage(project_id, notifications)

        # create a list of MediaFile
        if isinstance(file_filter, str):
            if file_filter == 'new':
                files = project.files_without_results()
                length = project.count_files_without_results()
            else:
                files = project.files.all()
                length = project.files.count()
        else:
            files = [project.file(identifier) for identifier in file_filter]
            length = len(files)

        media_files = map(lambda f: MediaFile(f, notifications), files)

        # load pipeline
        try:
            pipeline = pipelines.load_from_root_folder(project_id, model_root)

            # iterate over media files
            for index, file in enumerate(media_files):
                # remove old predictions
                file.remove_predictions()

                # create new predictions
                pipeline.execute(storage, file)

                # commit changes and yield progress
                db.session.commit()
                # (index + 1) so the final yield reports 1.0; the previous
                # "index / length" started at 0.0 and never reached 100 %
                yield (index + 1) / length, notifications
        finally:
            # always release the pipeline instance, even if execution failed
            if pipeline is not None:
                pipelines.free_instance(model_root)

    @staticmethod
    def load_and_pure_inference(pipelines: PipelineCache,
                                notifications: NotificationList,
                                notification_manager: NotificationManager,
                                project_id: int,
                                file_filter: List[int],
                                result_filter: dict[int, List[Result]]):
        """
        load the pipeline and call the execute function

        :param pipelines: pipeline cache
        :param notifications: notification object
        :param notification_manager: notification manager
        :param project_id: project id
        :param file_filter: list of file ids
        :param result_filter: dict of file id and list of results to classify
        :return: generator yielding (progress, notifications) tuples,
            with progress in (0, 1] and 1.0 after the last file
        """
        pipeline = None

        # create new database instance
        project = Project.query.get(project_id)
        model_root = project.model.root_folder
        storage = MediaStorage(project_id, notifications)

        length = len(file_filter)

        # load pipeline
        try:
            pipeline = pipelines.load_from_root_folder(project_id, model_root)

            # iterate over media files
            for index, file_id in enumerate(file_filter):
                file = project.file(file_id)
                file = MediaFile(file, notifications)
                bounding_boxes = [MediaBoundingBox(result)
                                  for result in result_filter[file_id]]

                # Perform inference.
                bbox_labels = pipeline.pure_inference(storage, file, bounding_boxes)

                # Add the labels determined in the inference process.
                for i, result in enumerate(result_filter[file_id]):
                    bbox_label = bbox_labels[i]
                    if isinstance(bbox_label, MediaLabel):
                        result.label_id = bbox_label.identifier
                        result.set_origin('user', commit=True)
                        notifications.add(notification_manager.edit_result, result)

                # yield progress
                # (index + 1) so the final yield reports 1.0; the previous
                # "index / length" started at 0.0 and never reached 100 %
                yield (index + 1) / length, notifications
        finally:
            # always release the pipeline instance, even if inference failed
            if pipeline is not None:
                pipelines.free_instance(model_root)

    @staticmethod
    def progress(progress: float, notifications: NotificationList):
        """
        fire notifications from the correct thread

        :param progress: [0, 1]
        :param notifications: Notificationlist
        :return: progress
        """
        notifications.fire()
        return progress