6
0

PredictModel.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138
  1. from typing import List
  2. from typing import Union
  3. from flask import abort
  4. from flask import make_response
  5. from flask import request
  6. from flask.views import View
  7. from pycs import app
  8. from pycs import db
  9. from pycs.database.File import File
  10. from pycs.database.Project import Project
  11. from pycs.frontend.notifications.NotificationList import NotificationList
  12. from pycs.interfaces.MediaFile import MediaFile
  13. from pycs.interfaces.MediaStorage import MediaStorage
  14. from pycs.jobs.JobGroupBusyException import JobGroupBusyException
  15. from pycs.jobs.JobRunner import JobRunner
  16. from pycs.util.PipelineCache import PipelineCache
  17. class PredictModel(View):
  18. """
  19. load a model and create predictions
  20. """
  21. # pylint: disable=arguments-differ
  22. methods = ['POST']
  23. def __init__(self, pipelines: PipelineCache):
  24. # pylint: disable=invalid-name
  25. self.pipelines = pipelines
  26. def dispatch_request(self, project_id):
  27. # extract request data
  28. data = request.get_json(force=True)
  29. if data.get('predict') not in ['all', 'new']:
  30. return abort(400)
  31. # find project
  32. project = Project.query.get(project_id)
  33. if project is None:
  34. return abort(404)
  35. # create job
  36. try:
  37. notifications = NotificationList()
  38. JobRunner.Run(project,
  39. 'Model Interaction',
  40. f'{project.name} (create predictions)',
  41. f'{project.name}/model-interaction',
  42. PredictModel.load_and_predict,
  43. self.pipelines, notifications,
  44. project.id, data['predict'],
  45. progress=self.progress)
  46. except JobGroupBusyException:
  47. return abort(400)
  48. return make_response()
  49. @staticmethod
  50. def load_and_predict(pipelines: PipelineCache, notifications: NotificationList,
  51. project_id: int, file_filter: Union[str, List[File]]):
  52. """
  53. load the pipeline and call the execute function
  54. :param pipelines: pipeline cache
  55. :param notifications: notification object
  56. :param project_id: project id
  57. :param file_filter: list of files or 'new' / 'all'
  58. :return:
  59. """
  60. pipeline = None
  61. try:
  62. project = Project.query.get(project_id)
  63. model = project.model
  64. storage = MediaStorage(project_id, notifications)
  65. # create a list of MediaFile
  66. if isinstance(file_filter, str):
  67. if file_filter == 'new':
  68. length = project.count_files_without_results()
  69. files = map(lambda f: MediaFile(f, notifications),
  70. project.files_without_results())
  71. else:
  72. length = project.count_files()
  73. files = map(lambda f: MediaFile(f, notifications),
  74. project.files.all())
  75. else:
  76. files = map(lambda f: MediaFile(project.file(f.id), notifications),
  77. file_filter)
  78. length = len(file_filter)
  79. # load pipeline
  80. try:
  81. pipeline = pipelines.load_from_root_folder(project, model.root_folder)
  82. # iterate over files
  83. index = 0
  84. for file in files:
  85. # remove old predictions
  86. file.remove_predictions()
  87. # create new predictions
  88. pipeline.execute(storage, file)
  89. # commit changes and yield progress
  90. db.session.commit()
  91. yield index / length, notifications
  92. index += 1
  93. except Exception as e:
  94. import traceback
  95. traceback.print_exc()
  96. app.logger.error(f"Pipeline Error #2: {e}")
  97. finally:
  98. if pipeline is not None:
  99. pipelines.free_instance(model.root_folder)
  100. except Exception as e:
  101. import traceback
  102. traceback.print_exc()
  103. app.logger.error(f"Pipeline Error #1: {e}")
  104. @staticmethod
  105. def progress(progress: float, notifications: NotificationList):
  106. """
  107. fire notifications from the correct thread
  108. :param progress: [0, 1]
  109. :param notifications: Notificationlist
  110. :return: progress
  111. """
  112. notifications.fire()
  113. return progress