PredictModel.py

from typing import List
from typing import Union

from flask import abort
from flask import make_response
from flask import request
from flask.views import View

from pycs import db
from pycs.database.Project import Project
from pycs.database.Result import Result
from pycs.frontend.notifications.NotificationList import NotificationList
from pycs.frontend.notifications.NotificationManager import NotificationManager
from pycs.interfaces.MediaFile import MediaFile
from pycs.interfaces.MediaBoundingBox import MediaBoundingBox
from pycs.interfaces.MediaStorage import MediaStorage
from pycs.jobs.JobGroupBusyException import JobGroupBusyException
from pycs.jobs.JobRunner import JobRunner
from pycs.util.PipelineCache import PipelineCache

class PredictModel(View):
    """
    load a model and create predictions
    """
    # pylint: disable=arguments-differ
    methods = ['POST']

    def __init__(self, nm: NotificationManager, jobs: JobRunner, pipelines: PipelineCache):
        # pylint: disable=invalid-name
        self.nm = nm
        self.jobs = jobs
        self.pipelines = pipelines

    def dispatch_request(self, user: str, project_id):
        # pylint: disable=unused-argument
        project = Project.get_or_404(project_id)

        # extract request data
        data = request.get_json(force=True)
        predict = data.get('predict')

        if predict is None:
            abort(400, "predict argument is missing")

        if predict not in ['all', 'new']:
            abort(400, "predict must be either 'all' or 'new'")

        # create job
        try:
            notifications = NotificationList(self.nm)

            self.jobs.run(project,
                          'Model Interaction',
                          f'{project.name} (create predictions)',
                          f'{project.id}/model-interaction',
                          PredictModel.load_and_predict,
                          self.pipelines, notifications,
                          project.id, predict,
                          progress=self.progress)
        except JobGroupBusyException:
            abort(400, "Model prediction is already running")

        return make_response()
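
    # Expected request body for dispatch_request above (values illustrative; the
    # URL rule that maps to this view is registered elsewhere in the application):
    #   {"predict": "all"}  -> re-predict every file in the project
    #   {"predict": "new"}  -> predict only files that have no results yet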

    @staticmethod
    def load_and_predict(pipelines: PipelineCache,
                         notifications: NotificationList,
                         project_id: int, file_filter: Union[str, List[int]]):
        """
        load the pipeline and call the execute function

        :param pipelines: pipeline cache
        :param notifications: notification object
        :param project_id: project id
        :param file_filter: list of file ids or 'new' / 'all'
        :return:
        """
        pipeline = None

        # fetch the project and its model from the database
        project = Project.query.get(project_id)
        model_root = project.model.root_folder
        storage = MediaStorage(project_id, notifications)

        # create a list of MediaFile
        if isinstance(file_filter, str):
            if file_filter == 'new':
                files = project.files_without_results()
                length = project.count_files_without_results()
            else:
                files = project.files.all()
                length = project.files.count()
        else:
            files = [project.file(identifier) for identifier in file_filter]
            length = len(files)

        media_files = map(lambda f: MediaFile(f, notifications), files)

        # load pipeline
        try:
            pipeline = pipelines.load_from_root_folder(project_id, model_root)

            # iterate over media files
            index = 0

            for file in media_files:
                # remove old predictions
                file.remove_predictions()

                # create new predictions
                pipeline.execute(storage, file)

                # commit changes and yield progress
                db.session.commit()
                yield index / length, notifications
                index += 1
        finally:
            if pipeline is not None:
                pipelines.free_instance(model_root)
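
    # Note: load_and_predict is a generator; each yield hands (fraction_done,
    # notifications) back to the JobRunner, which presumably forwards the fraction
    # to the progress callback (PredictModel.progress) registered via
    # jobs.run(..., progress=self.progress) in dispatch_request.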

    @staticmethod
    def load_and_pure_inference(pipelines: PipelineCache,
                                notifications: NotificationList,
                                notification_manager: NotificationManager,
                                project_id: int, file_filter: List[int],
                                bbox_id_filter: dict[int, List[int]], user: str):
        """
        load the pipeline and call the pure_inference function

        :param pipelines: pipeline cache
        :param notifications: notification object
        :param notification_manager: notification manager
        :param project_id: project id
        :param file_filter: list of file ids
        :param bbox_id_filter: dict of file id and list of bbox_ids to classify
        :param user: username of the user asking to predict the bounding box
        :return:
        """
        pipeline = None

        # fetch the project and its model from the database
        project = Project.query.get(project_id)
        model_root = project.model.root_folder
        storage = MediaStorage(project_id, notifications)

        # load pipeline
        try:
            pipeline = pipelines.load_from_root_folder(project_id, model_root)

            # iterate over media files
            index = 0
            length = len(file_filter)

            for file_id in file_filter:
                file = project.file(file_id)
                file = MediaFile(file, notifications)
                bounding_boxes = [MediaBoundingBox(Result.get_or_404(bbox_id))
                                  for bbox_id in bbox_id_filter[file_id]]

                # perform inference
                bbox_labels = pipeline.pure_inference(storage, file, bounding_boxes)

                # add the labels determined in the inference process
                for i, bbox_id in enumerate(bbox_id_filter[file_id]):
                    result = Result.get_or_404(bbox_id)
                    result.set_label(bbox_labels[i].identifier, commit=True)
                    result.set_origin('user', origin_user=user, commit=True)
                    notifications.add(notification_manager.edit_result, result)

                # commit changes and yield progress
                db.session.commit()
                yield index / length, notifications
                index += 1
        finally:
            if pipeline is not None:
                pipelines.free_instance(model_root)
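
    # Example shape of the load_and_pure_inference arguments (values illustrative):
    #   file_filter    = [3, 7]
    #   bbox_id_filter = {3: [12, 13], 7: [21]}  # file id -> result ids to classify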

    @staticmethod
    def progress(progress: float, notifications: NotificationList):
        """
        fire notifications from the correct thread

        :param progress: [0, 1]
        :param notifications: NotificationList
        :return: progress
        """
        notifications.fire()
        return progress
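
# A minimal sketch of how this view might be wired into the application, using
# Flask's View.as_view; the URL rule, endpoint name, and the app/nm/jobs/pipelines
# objects are assumptions and live elsewhere in pycs, not in this file:
#
#   app.add_url_rule(
#       '/<user>/projects/<int:project_id>/predictions',
#       view_func=PredictModel.as_view('predict_model', nm, jobs, pipelines),
#   )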