PredictModel.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138
  1. from typing import List
  2. from typing import Union
  3. from flask import abort
  4. from flask import make_response
  5. from flask import request
  6. from flask.views import View
  7. from pycs import db
  8. from pycs.database.File import File
  9. from pycs.database.Project import Project
  10. from pycs.frontend.notifications.NotificationList import NotificationList
  11. from pycs.frontend.notifications.NotificationManager import NotificationManager
  12. from pycs.interfaces.MediaFile import MediaFile
  13. from pycs.interfaces.MediaStorage import MediaStorage
  14. from pycs.jobs.JobGroupBusyException import JobGroupBusyException
  15. from pycs.jobs.JobRunner import JobRunner
  16. from pycs.util.PipelineCache import PipelineCache
  17. class PredictModel(View):
  18. """
  19. load a model and create predictions
  20. """
  21. # pylint: disable=arguments-differ
  22. methods = ['POST']
  23. def __init__(self, nm: NotificationManager, jobs: JobRunner, pipelines: PipelineCache):
  24. # pylint: disable=invalid-name
  25. self.nm = nm
  26. self.jobs = jobs
  27. self.pipelines = pipelines
  28. def dispatch_request(self, project_id):
  29. # extract request data
  30. data = request.get_json(force=True)
  31. predict = data.get('predict')
  32. if predict is None:
  33. abort(400, "predict argument is missing")
  34. if predict not in ['all', 'new']:
  35. abort(400, "predict must be either 'all' or 'new'")
  36. # find project
  37. project = Project.get_or_404(project_id)
  38. # create job
  39. try:
  40. notifications = NotificationList(self.nm)
  41. self.jobs.run(project,
  42. 'Model Interaction',
  43. f'{project.name} (create predictions)',
  44. f'{project.id}/model-interaction',
  45. PredictModel.load_and_predict,
  46. self.pipelines, notifications,
  47. project.id, predict,
  48. progress=self.progress)
  49. except JobGroupBusyException:
  50. abort(400, "Model prediction is already running")
  51. return make_response()
  52. @staticmethod
  53. def load_and_predict(pipelines: PipelineCache,
  54. notifications: NotificationList,
  55. project_id: int, file_filter: Union[str, List[int]]):
  56. """
  57. load the pipeline and call the execute function
  58. :param database: database object
  59. :param pipelines: pipeline cache
  60. :param notifications: notification object
  61. :param project_id: project id
  62. :param file_filter: list of file ids or 'new' / 'all'
  63. :return:
  64. """
  65. pipeline = None
  66. # create new database instance
  67. project = Project.query.get(project_id)
  68. model = project.model
  69. storage = MediaStorage(project_id, notifications)
  70. # create a list of MediaFile
  71. if isinstance(file_filter, str):
  72. if file_filter == 'new':
  73. files = project.files_without_results()
  74. length = project.count_files_without_results()
  75. else:
  76. files = project.files.all()
  77. length = project.files.count()
  78. else:
  79. files = [project.file(identifier) for identifier in file_filter]
  80. length = len(files)
  81. media_files = map(lambda f: MediaFile(f, notifications), files)
  82. # load pipeline
  83. try:
  84. pipeline = pipelines.load_from_root_folder(project, model.root_folder)
  85. # iterate over media files
  86. index = 0
  87. for file in media_files:
  88. # remove old predictions
  89. file.remove_predictions()
  90. # create new predictions
  91. pipeline.execute(storage, file)
  92. # commit changes and yield progress
  93. db.session.commit()
  94. yield index / length, notifications
  95. index += 1
  96. finally:
  97. if pipeline is not None:
  98. pipelines.free_instance(model.root_folder)
  99. @staticmethod
  100. def progress(progress: float, notifications: NotificationList):
  101. """
  102. fire notifications from the correct thread
  103. :param progress: [0, 1]
  104. :param notifications: Notificationlist
  105. :return: progress
  106. """
  107. notifications.fire()
  108. return progress