PredictModel.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141
  1. from typing import List
  2. from typing import Union
  3. from flask import abort
  4. from flask import make_response
  5. from flask import request
  6. from flask.views import View
  7. from pycs import app
  8. from pycs import db
  9. from pycs.database.File import File
  10. from pycs.database.Project import Project
  11. from pycs.frontend.notifications.NotificationList import NotificationList
  12. from pycs.frontend.notifications.NotificationManager import NotificationManager
  13. from pycs.interfaces.MediaFile import MediaFile
  14. from pycs.interfaces.MediaStorage import MediaStorage
  15. from pycs.jobs.JobGroupBusyException import JobGroupBusyException
  16. from pycs.jobs.JobRunner import JobRunner
  17. from pycs.util.PipelineCache import PipelineCache
  18. class PredictModel(View):
  19. """
  20. load a model and create predictions
  21. """
  22. # pylint: disable=arguments-differ
  23. methods = ['POST']
  24. def __init__(self, nm: NotificationManager, jobs: JobRunner, pipelines: PipelineCache):
  25. # pylint: disable=invalid-name
  26. self.nm = nm
  27. self.jobs = jobs
  28. self.pipelines = pipelines
  29. def dispatch_request(self, project_id):
  30. # extract request data
  31. data = request.get_json(force=True)
  32. if data.get('predict') not in ['all', 'new']:
  33. return abort(400)
  34. # find project
  35. project = Project.query.get(project_id)
  36. if project is None:
  37. return abort(404)
  38. # create job
  39. try:
  40. notifications = NotificationList(self.nm)
  41. self.jobs.run(project,
  42. 'Model Interaction',
  43. f'{project.name} (create predictions)',
  44. f'{project.name}/model-interaction',
  45. PredictModel.load_and_predict,
  46. self.pipelines, notifications,
  47. project.id, data['predict'],
  48. progress=self.progress)
  49. except JobGroupBusyException:
  50. return abort(400)
  51. return make_response()
  52. @staticmethod
  53. def load_and_predict(pipelines: PipelineCache, notifications: NotificationList,
  54. project_id: int, file_filter: Union[str, List[File]]):
  55. """
  56. load the pipeline and call the execute function
  57. :param pipelines: pipeline cache
  58. :param notifications: notification object
  59. :param project_id: project id
  60. :param file_filter: list of files or 'new' / 'all'
  61. :return:
  62. """
  63. pipeline = None
  64. try:
  65. project = Project.query.get(project_id)
  66. model = project.model
  67. storage = MediaStorage(project_id, notifications)
  68. # create a list of MediaFile
  69. if isinstance(file_filter, str):
  70. if file_filter == 'new':
  71. length = project.count_files_without_results()
  72. files = map(lambda f: MediaFile(f, notifications),
  73. project.files_without_results())
  74. else:
  75. length = project.count_files()
  76. files = map(lambda f: MediaFile(f, notifications),
  77. project.files.all())
  78. else:
  79. files = map(lambda f: MediaFile(project.file(f.id), notifications),
  80. file_filter)
  81. length = len(file_filter)
  82. # load pipeline
  83. try:
  84. pipeline = pipelines.load_from_root_folder(project, model.root_folder)
  85. # iterate over files
  86. index = 0
  87. for file in files:
  88. # remove old predictions
  89. file.remove_predictions()
  90. # create new predictions
  91. pipeline.execute(storage, file)
  92. # commit changes and yield progress
  93. db.session.commit()
  94. yield index / length, notifications
  95. index += 1
  96. except Exception as e:
  97. import traceback
  98. traceback.print_exc()
  99. app.logger.error(f"Pipeline Error #2: {e}")
  100. finally:
  101. if pipeline is not None:
  102. pipelines.free_instance(model.root_folder)
  103. except Exception as e:
  104. import traceback
  105. traceback.print_exc()
  106. app.logger.error(f"Pipeline Error #1: {e}")
  107. @staticmethod
  108. def progress(progress: float, notifications: NotificationList):
  109. """
  110. fire notifications from the correct thread
  111. :param progress: [0, 1]
  112. :param notifications: Notificationlist
  113. :return: progress
  114. """
  115. notifications.fire()
  116. return progress