PredictBoundingBox.py 2.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960
  1. from flask import abort
  2. from flask import make_response
  3. from flask import request
  4. from flask.views import View
  5. from pycs.database.Result import Result
  6. from pycs.database.File import File
  7. from pycs.frontend.endpoints.pipelines.PredictModel import PredictModel as Predict
  8. from pycs.frontend.notifications.NotificationList import NotificationList
  9. from pycs.frontend.notifications.NotificationManager import NotificationManager
  10. from pycs.jobs.JobGroupBusyException import JobGroupBusyException
  11. from pycs.jobs.JobRunner import JobRunner
  12. from pycs.util.PipelineCache import PipelineCache
  class PredictBoundingBox(View):
      """
      Load a model and create predictions for a given file, restricted to
      one existing result (= bounding box) of that file.

      The view only schedules the work: the prediction itself is executed by
      the job runner via ``Predict.load_and_pure_inference``.
      """
      # pylint: disable=arguments-differ
      methods = ['POST']

      def __init__(self, nm: NotificationManager, jobs: JobRunner, pipelines: PipelineCache):
          """
          Store the collaborators needed to schedule a prediction job.

          :param nm: notification manager, passed on to the job and used to
                     build the job's notification list
          :param jobs: job runner on which the prediction job is started
          :param pipelines: pipeline cache passed on to the prediction job
          """
          # pylint: disable=invalid-name
          self.nm = nm
          self.jobs = jobs
          self.pipelines = pipelines

      def dispatch_request(self, file_id, bbox_id):
          """
          Validate the request and start a prediction job for one bounding box.

          :param file_id: identifier of the file to run the prediction on
          :param bbox_id: identifier of the result (bounding box) to predict for
          :return: an empty response once the job has been handed to the runner

          Aborts with HTTP 400 if the JSON body is not an object carrying a
          truthy ``predict`` flag, or if the project's model-interaction job
          group is already busy. ``get_or_404`` presumably aborts with HTTP 404
          for unknown ids.
          """
          # find file and result (=bounding box)
          # We need the result to get (x,y,w,h)
          # NOTE(review): nothing here checks that the result belongs to the
          # file -- confirm whether that is enforced elsewhere.
          file = File.get_or_404(file_id)
          result = Result.get_or_404(bbox_id)

          # extract request data
          data = request.get_json(force=True)

          # The body may be any valid JSON value (null, list, string, number).
          # Only an object can carry the flag; reject everything else with the
          # same 400 instead of failing on ``.get`` with an AttributeError.
          if not isinstance(data, dict) or not data.get('predict', False):
              abort(400, "predict flag is missing")

          # get project and model
          project = file.project

          # create job
          try:
              notifications = NotificationList(self.nm)

              self.jobs.run(project,
                            'Model Interaction',
                            f'{project.name} (create predictions)',
                            f'{project.id}/model-interaction',
                            Predict.load_and_pure_inference,
                            self.pipelines, notifications, self.nm,
                            project.id, [file.id], {file.id: [result]},
                            progress=Predict.progress)

          except JobGroupBusyException:
              # the runner refused the job because this group is already busy
              abort(400, "File prediction is already running")

          return make_response()