PredictModel.py

from contextlib import closing

from flask import make_response, request, abort
from flask.views import View

from pycs.database.Database import Database
from pycs.frontend.notifications.NotificationManager import NotificationManager
from pycs.interfaces.MediaFile import MediaFile
from pycs.jobs.JobGroupBusyException import JobGroupBusyException
from pycs.jobs.JobRunner import JobRunner
from pycs.util.PipelineUtil import load_from_root_folder as load_pipeline

class PredictModel(View):
    """
    load a model and create predictions
    """
    # pylint: disable=arguments-differ
    methods = ['POST']

    def __init__(self, db: Database, nm: NotificationManager, jobs: JobRunner):
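        """
        Keep references to the database, the notification manager and the job
        runner; all three are used when a prediction job is created.
        """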
        # pylint: disable=invalid-name
        self.db = db
        self.nm = nm
        self.jobs = jobs

    def dispatch_request(self, project_id):
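        """
        Create a prediction job for the given project.

        Expects a JSON body whose 'predict' key is either 'all' (predict every
        file of the project) or 'new' (predict only files without results).
        Responds with 400 for an invalid body or a busy job group, 404 if the
        project does not exist, and an empty response once the job is queued.
        """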
        # extract request data
        data = request.get_json(force=True)

        if 'predict' not in data or data['predict'] not in ['all', 'new']:
            return abort(400)

        # find project
        project = self.db.project(project_id)
        if project is None:
            return abort(404)

        # get model
        model = project.model()

        # get data and results
        if data['predict'] == 'new':
            files = project.files_without_results()
        else:
            files = project.files()

        objects = list(map(MediaFile, files))

        # create job
        def store(index, length, result):
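            """
            Progress callback passed to the job runner: for every
            (index, length, result) tuple yielded by load_and_predict it
            replaces the file's previous pipeline results inside a single
            database transaction and returns the progress as a fraction
            between 0 and 1.
            """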
            # get file from list
            file = files[index]

            # start transaction
            with self.db:
                # remove current results from file
                for remove in file.results():
                    if remove.origin == 'pipeline':
                        remove.remove()
                        self.nm.remove_result(remove)

                # iterate over result entries
                for entry in result:
                    # extract entry type
                    entry_type = entry['type']
                    del entry['type']

                    # update file collection
                    if entry_type == 'collection':
                        file.set_collection_by_reference(entry['reference'])
                        self.nm.edit_file(file)
                        continue

                    # extract label from entry
                    if 'label' in entry:
                        label = entry['label']
                        del entry['label']
                    else:
                        label = None

                    # if entry_type == 'labeled-image':
                    #     for remove in file.results():
                    #         remove.remove()
                    #         self.nm.remove_result(remove)

                    # add result
                    created = file.create_result('pipeline', entry_type, label, entry)
                    self.nm.create_result(created)

            return (index + 1) / length

        try:
            self.jobs.run(project,
                          'Model Interaction',
                          f'{project.name} (create predictions)',
                          f'{project.name}/model-interaction',
                          self.load_and_predict, model, objects,
                          progress=store)
        except JobGroupBusyException:
            return abort(400)

        return make_response()

    @staticmethod
    def load_and_predict(model, files):
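        """
        Load the pipeline from the model's root folder, execute it on every
        given file and yield one (index, length, result) tuple per file for
        the progress callback.
        """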
        with closing(load_pipeline(model.root_folder)) as pipeline:
            length = len(files)

            for index in range(length):
                result = pipeline.execute(files[index])
                yield index, length, result