123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657 |
- from flask import abort
- from flask import make_response
- from flask import request
- from flask.views import View
- from pycs.database.File import File
- from pycs.frontend.endpoints.pipelines.PredictModel import PredictModel as Predict
- from pycs.frontend.notifications.NotificationList import NotificationList
- from pycs.frontend.notifications.NotificationManager import NotificationManager
- from pycs.jobs.JobGroupBusyException import JobGroupBusyException
- from pycs.jobs.JobRunner import JobRunner
- from pycs.util.PipelineCache import PipelineCache
class PredictFile(View):
    """
    load a model and create predictions for a given file

    The view accepts POST requests only. The request body has to be a JSON
    object containing a truthy ``predict`` entry; otherwise the request is
    rejected with status 400.
    """
    # pylint: disable=arguments-differ
    methods = ['POST']

    def __init__(self, nm: NotificationManager, jobs: JobRunner, pipelines: PipelineCache):
        """
        store the collaborators needed to schedule a prediction job

        :param nm: notification manager wrapped into a NotificationList per request
        :param jobs: job runner the prediction job is submitted to
        :param pipelines: pipeline cache handed over to the prediction job
        """
        # pylint: disable=invalid-name
        self.nm = nm
        self.jobs = jobs
        self.pipelines = pipelines

    def dispatch_request(self, file_id):
        """
        schedule a prediction job for the file identified by ``file_id``

        :param file_id: identifier passed to ``File.get_or_404``
                        (presumably aborts with 404 for unknown ids - confirm
                        against the File model)
        :return: an empty response once the job has been submitted
        """
        # find file
        file = File.get_or_404(file_id)

        # extract request data
        data = request.get_json(force=True)

        # A syntactically valid JSON body is not necessarily an object
        # (e.g. `null`, a list or a number). Reject those with the same 400
        # as a missing flag instead of failing with an AttributeError on
        # `.get`, which would surface as a 500.
        if not isinstance(data, dict) or not data.get('predict', False):
            abort(400, "predict flag is missing")

        # get the project the file belongs to
        project = file.project

        # create job
        try:
            notifications = NotificationList(self.nm)

            self.jobs.run(project,
                          'Model Interaction',
                          f'{project.name} (create predictions)',
                          f'{project.id}/model-interaction',
                          Predict.load_and_predict,
                          self.pipelines, notifications,
                          project.id, [file.id],
                          progress=Predict.progress)

        except JobGroupBusyException:
            # the job runner refused the job because its group is busy
            abort(400, "File prediction is already running")

        return make_response()
|