6
0

PredictModel.py 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
  1. from contextlib import closing
  2. from flask import make_response, request, abort
  3. from flask.views import View
  4. from pycs.database.Database import Database
  5. from pycs.frontend.notifications.NotificationManager import NotificationManager
  6. from pycs.interfaces.MediaFile import MediaFile
  7. from pycs.jobs.JobGroupBusyException import JobGroupBusyException
  8. from pycs.jobs.JobRunner import JobRunner
  9. from pycs.util.PipelineUtil import load_from_root_folder as load_pipeline
  10. class PredictModel(View):
  11. """
  12. load a model and create predictions
  13. """
  14. # pylint: disable=arguments-differ
  15. methods = ['POST']
  16. def __init__(self, db: Database, nm: NotificationManager, jobs: JobRunner):
  17. # pylint: disable=invalid-name
  18. self.db = db
  19. self.nm = nm
  20. self.jobs = jobs
  21. def dispatch_request(self, project_id):
  22. # extract request data
  23. data = request.get_json(force=True)
  24. if 'predict' not in data or data['predict'] not in ['all', 'new']:
  25. return abort(400)
  26. # find project
  27. project = self.db.project(project_id)
  28. if project is None:
  29. return abort(404)
  30. # get model
  31. model = project.model()
  32. # get data and results
  33. if data['predict'] == 'new':
  34. files = project.files_without_results()
  35. else:
  36. files = project.files()
  37. objects = list(map(MediaFile, files))
  38. # create job
  39. def store(index, length, result):
  40. with self.db:
  41. for remove in files[index].results():
  42. if remove.origin == 'pipeline':
  43. remove.remove()
  44. self.nm.remove_result(remove)
  45. for entry in result:
  46. file_type = entry['type']
  47. del entry['type']
  48. if 'label' in entry:
  49. label = entry['label']
  50. del entry['label']
  51. else:
  52. label = None
  53. if file_type == 'labeled-image':
  54. for remove in files[index].results():
  55. remove.remove()
  56. self.nm.remove_result(remove)
  57. created = files[index].create_result('pipeline', file_type, label, entry)
  58. self.nm.create_result(created)
  59. return (index + 1) / length
  60. try:
  61. self.jobs.run(project,
  62. 'Model Interaction',
  63. f'{project.name} (create predictions)',
  64. f'{project.name}/model-interaction',
  65. self.load_and_predict, model, objects,
  66. progress=store)
  67. except JobGroupBusyException:
  68. return abort(400)
  69. return make_response()
  70. @staticmethod
  71. def load_and_predict(model, files):
  72. with closing(load_pipeline(model.root_folder)) as pipeline:
  73. length = len(files)
  74. for index in range(length):
  75. result = pipeline.execute(files[index])
  76. yield index, length, result