PredictModel.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137
  1. from typing import List
  2. from typing import Union
  3. from flask import abort
  4. from flask import make_response
  5. from flask import request
  6. from flask.views import View
  7. from pycs import db
  8. from pycs.database.Project import Project
  9. from pycs.frontend.notifications.NotificationList import NotificationList
  10. from pycs.frontend.notifications.NotificationManager import NotificationManager
  11. from pycs.interfaces.MediaFile import MediaFile
  12. from pycs.interfaces.MediaStorage import MediaStorage
  13. from pycs.jobs.JobGroupBusyException import JobGroupBusyException
  14. from pycs.jobs.JobRunner import JobRunner
  15. from pycs.util.PipelineCache import PipelineCache
  16. class PredictModel(View):
  17. """
  18. load a model and create predictions
  19. """
  20. # pylint: disable=arguments-differ
  21. methods = ['POST']
  22. def __init__(self, nm: NotificationManager, jobs: JobRunner, pipelines: PipelineCache):
  23. # pylint: disable=invalid-name
  24. self.nm = nm
  25. self.jobs = jobs
  26. self.pipelines = pipelines
  27. def dispatch_request(self, project_id):
  28. # extract request data
  29. data = request.get_json(force=True)
  30. predict = data.get('predict')
  31. if predict is None:
  32. abort(400, "predict argument is missing")
  33. if predict not in ['all', 'new']:
  34. abort(400, "predict must be either 'all' or 'new'")
  35. # find project
  36. project = Project.get_or_404(project_id)
  37. # create job
  38. try:
  39. notifications = NotificationList(self.nm)
  40. self.jobs.run(project,
  41. 'Model Interaction',
  42. f'{project.name} (create predictions)',
  43. f'{project.id}/model-interaction',
  44. PredictModel.load_and_predict,
  45. self.pipelines, notifications,
  46. project.id, predict,
  47. progress=self.progress)
  48. except JobGroupBusyException:
  49. abort(400, "Model prediction is already running")
  50. return make_response()
  51. @staticmethod
  52. def load_and_predict(pipelines: PipelineCache,
  53. notifications: NotificationList,
  54. project_id: int, file_filter: Union[str, List[int]]):
  55. """
  56. load the pipeline and call the execute function
  57. :param database: database object
  58. :param pipelines: pipeline cache
  59. :param notifications: notification object
  60. :param project_id: project id
  61. :param file_filter: list of file ids or 'new' / 'all'
  62. :return:
  63. """
  64. pipeline = None
  65. # create new database instance
  66. project = Project.query.get(project_id)
  67. model_root = project.model.root_folder
  68. storage = MediaStorage(project_id, notifications)
  69. # create a list of MediaFile
  70. if isinstance(file_filter, str):
  71. if file_filter == 'new':
  72. files = project.files_without_results()
  73. length = project.count_files_without_results()
  74. else:
  75. files = project.files.all()
  76. length = project.files.count()
  77. else:
  78. files = [project.file(identifier) for identifier in file_filter]
  79. length = len(files)
  80. media_files = map(lambda f: MediaFile(f, notifications), files)
  81. # load pipeline
  82. try:
  83. pipeline = pipelines.load_from_root_folder(project, model_root)
  84. # iterate over media files
  85. index = 0
  86. for file in media_files:
  87. # remove old predictions
  88. file.remove_predictions()
  89. # create new predictions
  90. pipeline.execute(storage, file)
  91. # commit changes and yield progress
  92. db.session.commit()
  93. yield index / length, notifications
  94. index += 1
  95. finally:
  96. if pipeline is not None:
  97. pipelines.free_instance(model_root)
  98. @staticmethod
  99. def progress(progress: float, notifications: NotificationList):
  100. """
  101. fire notifications from the correct thread
  102. :param progress: [0, 1]
  103. :param notifications: Notificationlist
  104. :return: progress
  105. """
  106. notifications.fire()
  107. return progress