PredictModel.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
  1. from typing import Union, List
  2. from flask import make_response, request, abort
  3. from flask.views import View
  4. from pycs.database.Database import Database
  5. from pycs.database.File import File
  6. from pycs.frontend.notifications.NotificationList import NotificationList
  7. from pycs.frontend.notifications.NotificationManager import NotificationManager
  8. from pycs.interfaces.MediaFile import MediaFile
  9. from pycs.interfaces.MediaStorage import MediaStorage
  10. from pycs.jobs.JobGroupBusyException import JobGroupBusyException
  11. from pycs.jobs.JobRunner import JobRunner
  12. from pycs.util.PipelineCache import PipelineCache
  class PredictModel(View):
      """
      load a model and create predictions
      """
      # pylint: disable=arguments-differ

      methods = ['POST']

      def __init__(self,
                   db: Database, nm: NotificationManager, jobs: JobRunner, pipelines: PipelineCache):
          # pylint: disable=invalid-name
          self.db = db
          self.nm = nm
          self.jobs = jobs
          self.pipelines = pipelines

      def dispatch_request(self, project_id):
          """
          validate the request body and start a background job that
          (re-)creates predictions for the files of a project

          :param project_id: project id taken from the URL
          :return: empty response on success; aborts with 400 if the body is
                   invalid or the job group is busy, 404 if the project does
                   not exist
          """
          # extract request data; force=True parses the body as JSON regardless
          # of the Content-Type header, so the result may be any JSON value
          # (null, list, ...) and must be checked to be an object first
          data = request.get_json(force=True)
          if not isinstance(data, dict) or data.get('predict') not in ('all', 'new'):
              return abort(400)

          # find project
          project = self.db.project(project_id)
          if project is None:
              return abort(404)

          # create job
          try:
              notifications = NotificationList(self.nm)

              self.jobs.run(project,
                            'Model Interaction',
                            f'{project.name} (create predictions)',
                            f'{project.name}/model-interaction',
                            self.load_and_predict,
                            self.db, self.pipelines, notifications,
                            project.identifier, data['predict'],
                            progress=self.progress)
          except JobGroupBusyException:
              return abort(400)

          return make_response()

      @staticmethod
      def load_and_predict(database: Database, pipelines: PipelineCache,
                           notifications: NotificationList,
                           project_id: int, file_filter: Union[str, List[File]]):
          """
          load the pipeline and call the execute function

          This is a generator. After each file has been processed the database
          copy is committed and ``(progress, notifications)`` is yielded, where
          progress is the fraction of files processed so far (the last yield
          reports 1.0). The database copy is always closed and the pipeline
          instance always freed, even if the consumer stops early.

          :param database: database object (a copy of it is used internally)
          :param pipelines: pipeline cache
          :param notifications: notification object
          :param project_id: project id
          :param file_filter: list of files or 'new' / 'all'
          :return: generator of (progress, notifications) tuples
          """
          database_copy = None
          pipeline = None

          # create new database instance
          try:
              database_copy = database.copy()
              project = database_copy.project(project_id)
              model = project.model()
              storage = MediaStorage(database_copy, project_id, notifications)

              length, files = PredictModel._media_files(project, file_filter, notifications)

              # load pipeline
              try:
                  pipeline = pipelines.load_from_root_folder(project, model.root_folder)

                  # iterate over files; counting from 1 so the yielded value is
                  # the fraction already done (division only happens when
                  # there is at least one file, so length > 0 here)
                  for processed, file in enumerate(files, start=1):
                      # remove old predictions
                      file.remove_predictions()

                      # create new predictions
                      pipeline.execute(storage, file)

                      # commit changes and yield progress
                      database_copy.commit()
                      yield processed / length, notifications

              finally:
                  if pipeline is not None:
                      pipelines.free_instance(model.root_folder)

          finally:
              if database_copy is not None:
                  database_copy.close()

      @staticmethod
      def _media_files(project, file_filter, notifications):
          """
          build the lazily-evaluated MediaFile objects to run predictions on

          :param project: project object obtained from the database copy
          :param file_filter: list of files or 'new' / 'all'
          :param notifications: notification object handed to every MediaFile
          :return: tuple of (number of files, iterable of MediaFile)
          """
          if isinstance(file_filter, str):
              if file_filter == 'new':
                  return (project.count_files_without_results(),
                          (MediaFile(f, notifications)
                           for f in project.files_without_results()))

              # any other string ('all' is the only one dispatch_request allows)
              return (project.count_files(),
                      (MediaFile(f, notifications) for f in project.files()))

          # re-fetch every given file through the project loaded from the
          # database copy (presumably so the objects are bound to the copy
          # rather than the caller's connection - verify)
          return (len(file_filter),
                  (MediaFile(project.file(f.identifier), notifications)
                   for f in file_filter))

      @staticmethod
      def progress(progress: float, notifications: NotificationList):
          """
          fire notifications from the correct thread

          :param progress: [0, 1]
          :param notifications: NotificationList
          :return: progress
          """
          notifications.fire()
          return progress