PredictFile.py 1.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657
  1. from flask import abort
  2. from flask import make_response
  3. from flask import request
  4. from flask.views import View
  5. from pycs.database.File import File
  6. from pycs.frontend.endpoints.pipelines.PredictModel import PredictModel as Predict
  7. from pycs.frontend.notifications.NotificationList import NotificationList
  8. from pycs.frontend.notifications.NotificationManager import NotificationManager
  9. from pycs.jobs.JobGroupBusyException import JobGroupBusyException
  10. from pycs.jobs.JobRunner import JobRunner
  11. from pycs.util.PipelineCache import PipelineCache
  class PredictFile(View):
      """
      Load a model and create predictions for a given file.
      """
      # pylint: disable=arguments-differ
      methods = ['POST']

      def __init__(self, nm: NotificationManager, jobs: JobRunner, pipelines: PipelineCache):
          """
          Store the collaborators used while handling a request.

          :param nm: notification manager, wrapped in a NotificationList per job
          :param jobs: job runner that executes the prediction in the background
          :param pipelines: pipeline cache handed to the prediction job
          """
          # pylint: disable=invalid-name
          self.nm = nm
          self.jobs = jobs
          self.pipelines = pipelines

      def dispatch_request(self, file_id):
          """
          Start a job that creates predictions for a single file.

          The request body must be a JSON object with a truthy ``predict`` field.

          :param file_id: identifier of the file to predict (URL parameter)
          :return: an empty response once the job has been handed to the runner
          """
          # find file; get_or_404 presumably aborts with 404 for unknown ids
          file = File.get_or_404(file_id)

          # extract request data; force=True parses the body as JSON regardless
          # of the request's Content-Type header
          data = request.get_json(force=True)

          # a valid JSON body that is not an object (e.g. `null` or a list) has
          # no `.get`; reject it with 400 instead of crashing with a 500
          if not isinstance(data, dict) or not data.get('predict', False):
              abort(400, "predict flag is missing")

          # the prediction job runs in the context of the file's project
          project = file.project
          notifications = NotificationList(self.nm)

          # create job; the runner raises JobGroupBusyException when the
          # '<project id>/model-interaction' group is busy (presumably another
          # model interaction for this project is still running)
          try:
              self.jobs.run(project,
                            'Model Interaction',
                            f'{project.name} (create predictions)',
                            f'{project.id}/model-interaction',
                            Predict.load_and_predict,
                            self.pipelines, notifications,
                            project.id, [file.id],
                            progress=Predict.progress)
          except JobGroupBusyException:
              abort(400, "File prediction is already running")

          return make_response()