from flask import make_response, request, abort
from flask.views import View

from pycs.database.Database import Database
from pycs.frontend.endpoints.pipelines.PredictModel import PredictModel
from pycs.frontend.notifications.NotificationManager import NotificationManager
from pycs.interfaces.MediaFile import MediaFile
from pycs.jobs.JobGroupBusyException import JobGroupBusyException
from pycs.jobs.JobRunner import JobRunner


class PredictFile(View):
    """
    load a model and create predictions for a given file
    """
    # pylint: disable=arguments-differ
    methods = ['POST']

    def __init__(self, db: Database, nm: NotificationManager, jobs: JobRunner):
        # pylint: disable=invalid-name
        self.db = db
        self.nm = nm
        self.jobs = jobs

    def dispatch_request(self, file_id):
        """
        Schedule a prediction job for the file identified by `file_id`.

        The request body must be a JSON object containing `"predict": true`.

        :param file_id: identifier passed to `Database.file`
        :return: an empty response once the job has been handed to the job runner
        :raises: HTTP 400 if the body is malformed or the job group is busy,
                 HTTP 404 if no file with `file_id` exists
        """
        # extract request data
        data = request.get_json(force=True)

        # a JSON body that is not an object (e.g. `null` or a list) would
        # otherwise raise a TypeError below and surface as a 500
        if not isinstance(data, dict) or data.get('predict') is not True:
            return abort(400)

        # find file
        file = self.db.file(file_id)
        if file is None:
            return abort(404)

        media_file = MediaFile(file)

        # get project and model
        project = file.project()
        model = project.model()

        # create job
        def store(index, length, result):
            """Persist one batch of predictions and return the job progress."""
            with self.db:
                # predictions from an earlier pipeline run are replaced
                self._remove_results(file, origin='pipeline')

                for raw_entry in result:
                    # work on a copy so the caller's result objects stay intact
                    entry = dict(raw_entry)
                    file_type = entry.pop('type')
                    label = entry.pop('label', None)

                    # a labeled-image result replaces every existing result
                    if file_type == 'labeled-image':
                        self._remove_results(file)

                    created = file.create_result('pipeline', file_type, label, entry)
                    self.nm.create_result(created)

            return (index + 1) / length

        try:
            self.jobs.run(project,
                          'Model Interaction',
                          f'{project.name} (create predictions)',
                          f'{project.name}/model-interaction',
                          PredictModel.load_and_predict,
                          model, [media_file],
                          progress=store)
        except JobGroupBusyException:
            return abort(400)

        return make_response()

    def _remove_results(self, file, origin=None):
        """
        Remove results of `file` and notify clients about each removal.

        :param file: file whose results are removed
        :param origin: if given, only results with this `origin` are removed;
                       otherwise all results are removed
        """
        for result in file.results():
            if origin is None or result.origin == origin:
                result.remove()
                self.nm.remove_result(result)