PredictFile.py 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. from flask import make_response, request, abort
  2. from flask.views import View
  3. from pycs.database.Database import Database
  4. from pycs.frontend.endpoints.pipelines.PredictModel import PredictModel
  5. from pycs.frontend.notifications.NotificationManager import NotificationManager
  6. from pycs.interfaces.MediaFile import MediaFile
  7. from pycs.jobs.JobGroupBusyException import JobGroupBusyException
  8. from pycs.jobs.JobRunner import JobRunner
  class PredictFile(View):
      """
      Load a model and create predictions for a given file.

      Handles ``POST`` requests whose JSON body is ``{"predict": true}``.
      The prediction is handed to the ``JobRunner``; the results it produces
      are written back to the database by the ``store`` progress callback.
      """
      # pylint: disable=arguments-differ
      methods = ['POST']

      def __init__(self, db: Database, nm: NotificationManager, jobs: JobRunner):
          """
          :param db: database access object used to look up files and results
          :param nm: notification manager informed about removed / created results
          :param jobs: job runner that executes the prediction
          """
          # pylint: disable=invalid-name
          self.db = db
          self.nm = nm
          self.jobs = jobs

      def dispatch_request(self, file_id):
          """
          Start a prediction for the file identified by ``file_id``.

          :param file_id: id of the file to create predictions for
          :return: an empty response after the prediction was handed to the job runner
          :raises HTTPException: 400 (via ``abort``) if the body is not
              ``{"predict": true}`` or the job group is busy,
              404 if no file with ``file_id`` exists
          """
          # extract request data; force=True parses the body as JSON regardless
          # of the request's Content-Type header
          data = request.get_json(force=True)

          # Any JSON value may arrive here (object, list, string, null, ...).
          # Checking for an object first turns bodies such as `null` or
          # `"predict"` into a 400 instead of an unhandled TypeError (500).
          if not isinstance(data, dict) or data.get('predict') is not True:
              return abort(400)

          # find file
          file = self.db.file(file_id)
          if file is None:
              return abort(404)

          media_file = MediaFile(file)

          # get project and model
          project = file.project()
          model = project.model()

          # create job
          def store(index, length, result):
              """
              Replace the file's pipeline results with the newly predicted ones.

              :param index: presumably the index of the processed item
                  (TODO confirm against JobRunner's progress contract)
              :param length: presumably the total number of items
              :param result: iterable of result dicts; each must contain a
                  ``type`` key and may contain a ``label`` key, the remaining
                  keys are passed on as the result data
              :return: progress fraction ``(index + 1) / length``
              """
              with self.db:
                  # drop results from earlier pipeline runs; results with any
                  # other origin are kept
                  for remove in file.results():
                      if remove.origin == 'pipeline':
                          remove.remove()
                          self.nm.remove_result(remove)

                  for entry in result:
                      # type and label are passed to create_result as separate
                      # arguments, so take them out of the payload dict
                      file_type = entry.pop('type')
                      label = entry.pop('label', None)

                      # an image-level label replaces every result of the file,
                      # regardless of origin (presumably including results
                      # created earlier in this loop)
                      if file_type == 'labeled-image':
                          for remove in file.results():
                              remove.remove()
                              self.nm.remove_result(remove)

                      created = file.create_result('pipeline', file_type, label, entry)
                      self.nm.create_result(created)

              return (index + 1) / length

          try:
              self.jobs.run(project,
                            'Model Interaction',
                            f'{project.name} (create predictions)',
                            f'{project.name}/model-interaction',
                            PredictModel.load_and_predict, model, [media_file],
                            progress=store)
          except JobGroupBusyException:
              # presumably raised while another job of the same group
              # (<project>/model-interaction) is still running
              return abort(400)

          return make_response()