PredictFile.py 1.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253
  1. from flask import make_response, request, abort
  2. from flask.views import View
  3. from pycs.frontend.endpoints.pipelines.PredictModel import PredictModel as Predict
  4. from pycs.frontend.notifications.NotificationList import NotificationList
  5. from pycs.frontend.notifications.NotificationManager import NotificationManager
  6. from pycs.jobs.JobGroupBusyException import JobGroupBusyException
  7. from pycs.jobs.JobRunner import JobRunner
  8. from pycs.util.PipelineCache import PipelineCache
  9. class PredictFile(View):
  10. """
  11. load a model and create predictions or a given file
  12. """
  13. # pylint: disable=arguments-differ
  14. methods = ['POST']
  15. def __init__(self, nm: NotificationManager, jobs: JobRunner, pipelines: PipelineCache):
  16. # pylint: disable=invalid-name
  17. self.nm = nm
  18. self.jobs = jobs
  19. self.pipelines = pipelines
  20. def dispatch_request(self, file_id):
  21. # extract request data
  22. data = request.get_json(force=True)
  23. if not data.get('predict', False):
  24. abort(400, "predict flag is missing")
  25. # find file
  26. file = File.get_or_404(file_id)
  27. # get project and model
  28. project = file.project
  29. # create job
  30. try:
  31. notifications = NotificationList(self.nm)
  32. self.jobs.run(project,
  33. 'Model Interaction',
  34. f'{project.name} (create predictions)',
  35. f'{project.id}/model-interaction',
  36. Predict.load_and_predict,
  37. self.pipelines, notifications, project.id, [file.id],
  38. progress=Predict.progress)
  39. except JobGroupBusyException:
  40. abort(400, "File prediction is already running")
  41. return make_response()