6
0

PredictBoundingBox.py 2.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061
  1. from flask import abort
  2. from flask import make_response
  3. from flask import request
  4. from flask.views import View
  5. from pycs.database.Result import Result
  6. from pycs.database.File import File
  7. from pycs.frontend.endpoints.pipelines.PredictModel import PredictModel as Predict
  8. from pycs.frontend.notifications.NotificationList import NotificationList
  9. from pycs.frontend.notifications.NotificationManager import NotificationManager
  10. from pycs.jobs.JobGroupBusyException import JobGroupBusyException
  11. from pycs.jobs.JobRunner import JobRunner
  12. from pycs.util.PipelineCache import PipelineCache
  class PredictBoundingBox(View):
      """
      Load a model and create predictions for a single bounding box of a given file.

      The view validates the request and hands the prediction work
      (``Predict.load_and_pure_inference``) to the ``JobRunner``; the HTTP
      response itself is empty.
      """
      # pylint: disable=arguments-differ
      methods = ['POST']

      def __init__(self, nm: NotificationManager, jobs: JobRunner, pipelines: PipelineCache):
          # pylint: disable=invalid-name
          self.nm = nm
          self.jobs = jobs
          self.pipelines = pipelines

      def dispatch_request(self, user: str, file_id, bbox_id):
          """
          Schedule a prediction job for one bounding box of one file.

          :param user: user on whose behalf the prediction runs; forwarded to the job
          :param file_id: id of the file the bounding box belongs to
          :param bbox_id: id of the result (= bounding box) to predict for
          :return: an empty response once the job has been submitted
          :raises HTTPException: 404 if the file or the result does not exist;
              400 if the request body is not a JSON object, lacks a truthy
              ``predict`` flag, or the job runner reports the job group as busy
              (``JobGroupBusyException``)
          """
          # find file and result (= bounding box). The prediction job needs the
          # result later to get (x, y, w, h); here we only check that it exists.
          file = File.get_or_404(file_id)
          Result.get_or_404(bbox_id)
          # NOTE(review): nothing verifies that the result belongs to `file`, so a
          # bbox id from another file/project is accepted here. Consider comparing
          # it with file.id if Result exposes its owning file -- verify the schema.

          # extract request data. get_json(force=True) returns whatever JSON value
          # was sent (a list, null, a string, ...), so reject anything that is not
          # an object instead of failing on .get() with a server error.
          data = request.get_json(force=True)
          if not isinstance(data, dict):
              abort(400, "request body must be a JSON object")
          if not data.get('predict', False):
              abort(400, "predict flag is missing")

          # the job is scoped to the file's project
          project = file.project

          # create job; only the scheduling call can report a busy job group
          notifications = NotificationList(self.nm)
          try:
              self.jobs.run(project,
                            'Model Interaction',
                            f'{project.name} (create predictions)',
                            f'{project.id}/model-interaction',
                            Predict.load_and_pure_inference,
                            self.pipelines, notifications, self.nm,
                            project.id, [file.id], {file.id: [bbox_id]},
                            user, progress=Predict.progress)
          except JobGroupBusyException:
              abort(400, "File prediction is already running")

          return make_response()