import traceback
from typing import Any

from flask import abort
from flask import make_response
from flask import request
from flask.views import View

from pycs import app
from pycs import db
from pycs.database.Database import Database
from pycs.database.Project import Project
from pycs.frontend.notifications.NotificationList import NotificationList
from pycs.frontend.notifications.NotificationManager import NotificationManager
from pycs.interfaces.MediaFile import MediaFile
from pycs.interfaces.MediaStorage import MediaStorage
from pycs.jobs.JobGroupBusyException import JobGroupBusyException
from pycs.jobs.JobRunner import JobRunner
from pycs.util.PipelineCache import PipelineCache


class PredictModel(View):
    """
    load a model and create predictions
    """
    # pylint: disable=arguments-differ
    methods = ['POST']

    def __init__(self, nm: NotificationManager, jobs: JobRunner, pipelines: PipelineCache):
        # pylint: disable=invalid-name
        self.nm = nm
        self.jobs = jobs
        self.pipelines = pipelines

    def dispatch_request(self, project_id):
        """
        Queue a 'Model Interaction' job that creates predictions for a project.

        The request body must be a JSON object whose ``predict`` key is either
        ``'all'`` or ``'new'``; the value is forwarded to ``load_and_predict``.

        :param project_id: id of the project to predict for
        :return: an empty response once the job has been handed to the runner
        :raises HTTPException: 400 for a malformed body or when the job group
            is already busy, 404 when the project does not exist
        """
        # extract request data; valid JSON that is not an object (a list,
        # string, number, ...) has no .get() and must be rejected with 400
        # instead of crashing with a 500
        data = request.get_json(force=True)
        if not isinstance(data, dict) or data.get('predict') not in ('all', 'new'):
            return abort(400)

        # find project
        project = Project.query.get(project_id)
        if project is None:
            return abort(404)

        # create job
        try:
            notifications = NotificationList(self.nm)

            self.jobs.run(project,
                          'Model Interaction',
                          f'{project.name} (create predictions)',
                          f'{project.name}/model-interaction',
                          PredictModel.load_and_predict,
                          self.pipelines,
                          notifications,
                          project.id,
                          data['predict'],
                          progress=self.progress)
        except JobGroupBusyException:
            return abort(400)

        return make_response()

    @staticmethod
    def _collect_files(project: Project, notifications: NotificationList, file_filter: Any):
        """
        Build the (lazy) sequence of MediaFile objects to predict and its length.

        :param project: project the files belong to
        :param notifications: passed to every created MediaFile
        :param file_filter: ``'new'`` selects files without results, any other
            string selects all files; otherwise an object supporting ``len()``
            whose items expose an ``id`` attribute
        :return: tuple ``(files, length)``
        """
        if isinstance(file_filter, str):
            if file_filter == 'new':
                length = project.count_files_without_results()
                files = map(lambda f: MediaFile(f, notifications),
                            project.files_without_results())
            else:
                length = project.count_files()
                files = map(lambda f: MediaFile(f, notifications),
                            project.files.all())
        else:
            files = map(lambda f: MediaFile(project.file(f.id), notifications),
                        file_filter)
            length = len(file_filter)

        return files, length

    @staticmethod
    def load_and_predict(pipelines: PipelineCache, notifications: NotificationList,
                         project_id: int, file_filter: Any):
        """
        Job body: load the project's pipeline and re-create predictions file by file.

        This is a generator; after each processed file it commits the session
        and yields ``(index / length, notifications)`` where ``index`` is the
        zero-based position of the file just processed. Errors are logged and
        end the job instead of propagating.

        :param pipelines: cache used to obtain and release the pipeline instance
        :param notifications: notification list handed to storage and media files
        :param project_id: id of the project whose model is used
        :param file_filter: ``'new'``, any other string (all files), or a sized
            collection of objects with an ``id`` attribute
        """
        pipeline = None

        try:
            # load project and its model from the database
            project = Project.query.get(project_id)
            if project is None:
                app.logger.warning(f"Pipeline Error #1: project {project_id} not found")
                return

            model = project.model
            storage = MediaStorage(project_id, notifications)

            # create a list of MediaFile
            files, length = PredictModel._collect_files(project, notifications, file_filter)

            # load pipeline
            try:
                pipeline = pipelines.load_from_root_folder(project, model.root_folder)

                # iterate over files
                for index, file in enumerate(files):
                    # remove old predictions
                    file.remove_predictions()

                    # create new predictions
                    pipeline.execute(storage, file)

                    # commit changes and yield progress
                    db.session.commit()
                    yield index / length, notifications

            except Exception as e:
                # discard the failed file's uncommitted changes so they are not
                # committed later by someone else and the session stays usable
                db.session.rollback()
                traceback.print_exc()
                app.logger.warning(f"Pipeline Error #2: {e}")

            finally:
                if pipeline is not None:
                    pipelines.free_instance(model.root_folder)

        except Exception as e:
            traceback.print_exc()
            app.logger.warning(f"Pipeline Error #1: {e}")

    @staticmethod
    def progress(progress: float, notifications: NotificationList):
        """
        Progress callback for the job runner: fire pending notifications.

        :param progress: progress value yielded by ``load_and_predict``
        :param notifications: notification list to fire
        :return: the unchanged progress value
        """
        notifications.fire()
        return progress