6
0

PredictFile.py 1.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556
  1. from flask import make_response, request, abort
  2. from flask.views import View
  3. from pycs.database.Database import Database
  4. from pycs.database.File import File
  5. from pycs.frontend.endpoints.pipelines.PredictModel import PredictModel
  6. from pycs.frontend.notifications.NotificationList import NotificationList
  7. from pycs.frontend.notifications.NotificationManager import NotificationManager
  8. from pycs.jobs.JobGroupBusyException import JobGroupBusyException
  9. from pycs.jobs.JobRunner import JobRunner
  10. from pycs.util.PipelineCache import PipelineCache
  11. class PredictFile(View):
  12. """
  13. load a model and create predictions or a given file
  14. """
  15. # pylint: disable=arguments-differ
  16. methods = ['POST']
  17. def __init__(self, nm: NotificationManager, jobs: JobRunner, pipelines: PipelineCache):
  18. # pylint: disable=invalid-name
  19. self.nm = nm
  20. self.jobs = jobs
  21. self.pipelines = pipelines
  22. def dispatch_request(self, file_id):
  23. # extract request data
  24. data = request.get_json(force=True)
  25. if 'predict' not in data or data['predict'] is not True:
  26. return abort(400)
  27. # find file
  28. file = File.query.get(file_id)
  29. if file is None:
  30. return abort(404)
  31. # get project and model
  32. project = file.project
  33. # create job
  34. try:
  35. notifications = NotificationList(self.nm)
  36. self.jobs.run(project,
  37. 'Model Interaction',
  38. f'{project.name} (create predictions)',
  39. f'{project.name}/model-interaction',
  40. PredictModel.load_and_predict,
  41. self.pipelines, notifications, project.id, [file],
  42. progress=PredictModel.progress)
  43. except JobGroupBusyException:
  44. return abort(400)
  45. return make_response()