PredictModel.py 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207
from typing import Dict
from typing import List
from typing import Union

from flask import abort
from flask import make_response
from flask import request
from flask.views import View

from pycs import db
from pycs.database.Project import Project
from pycs.database.Result import Result
from pycs.frontend.notifications.NotificationList import NotificationList
from pycs.frontend.notifications.NotificationManager import NotificationManager
from pycs.interfaces.MediaFile import MediaFile
from pycs.interfaces.MediaLabel import MediaLabel
from pycs.interfaces.MediaBoundingBox import MediaBoundingBox
from pycs.interfaces.MediaStorage import MediaStorage
from pycs.jobs.JobGroupBusyException import JobGroupBusyException
from pycs.jobs.JobRunner import JobRunner
from pycs.util.PipelineCache import PipelineCache
  19. class PredictModel(View):
  20. """
  21. load a model and create predictions
  22. """
  23. # pylint: disable=arguments-differ
  24. methods = ['POST']
  25. def __init__(self, nm: NotificationManager, jobs: JobRunner, pipelines: PipelineCache):
  26. # pylint: disable=invalid-name
  27. self.nm = nm
  28. self.jobs = jobs
  29. self.pipelines = pipelines
  30. def dispatch_request(self, user: str, project_id):
  31. # pylint: disable=unused-argument
  32. project = Project.get_or_404(project_id)
  33. # extract request data
  34. data = request.get_json(force=True)
  35. predict = data.get('predict')
  36. if predict is None:
  37. abort(400, "predict argument is missing")
  38. if predict not in ['all', 'new']:
  39. abort(400, "predict must be either 'all' or 'new'")
  40. # create job
  41. try:
  42. notifications = NotificationList(self.nm)
  43. self.jobs.run(project,
  44. 'Model Interaction',
  45. f'{project.name} (create predictions)',
  46. f'{project.id}/model-interaction',
  47. PredictModel.load_and_predict,
  48. self.pipelines, notifications,
  49. project.id, predict,
  50. progress=self.progress)
  51. except JobGroupBusyException:
  52. abort(400, "Model prediction is already running")
  53. return make_response()
  54. @staticmethod
  55. def load_and_predict(pipelines: PipelineCache,
  56. notifications: NotificationList,
  57. project_id: int, file_filter: Union[str, List[int]]):
  58. """
  59. load the pipeline and call the execute function
  60. :param database: database object
  61. :param pipelines: pipeline cache
  62. :param notifications: notification object
  63. :param project_id: project id
  64. :param file_filter: list of file ids or 'new' / 'all'
  65. :return:
  66. """
  67. pipeline = None
  68. # create new database instance
  69. project = Project.query.get(project_id)
  70. model_root = project.model.root_folder
  71. storage = MediaStorage(project_id, notifications)
  72. # create a list of MediaFile
  73. if isinstance(file_filter, str):
  74. if file_filter == 'new':
  75. files = project.files_without_results()
  76. length = project.count_files_without_results()
  77. else:
  78. files = project.files.all()
  79. length = project.files.count()
  80. else:
  81. files = [project.file(identifier) for identifier in file_filter]
  82. length = len(files)
  83. media_files = map(lambda f: MediaFile(f, notifications), files)
  84. # load pipeline
  85. try:
  86. pipeline = pipelines.load_from_root_folder(project_id, model_root)
  87. # iterate over media files
  88. index = 0
  89. for file in media_files:
  90. # remove old predictions
  91. file.remove_predictions()
  92. # create new predictions
  93. pipeline.execute(storage, file)
  94. # commit changes and yield progress
  95. db.session.commit()
  96. yield index / length, notifications
  97. index += 1
  98. finally:
  99. if pipeline is not None:
  100. pipelines.free_instance(model_root)
  101. @staticmethod
  102. def load_and_pure_inference(pipelines: PipelineCache,
  103. notifications: NotificationList,
  104. notification_manager: NotificationManager,
  105. project_id: int, file_filter: List[int],
  106. bbox_id_filter: dict[int, List[int]], user: str):
  107. """
  108. load the pipeline and call the execute function
  109. :param database: database object
  110. :param pipelines: pipeline cache
  111. :param notifications: notification object
  112. :param notification_manager: notification manager
  113. :param project_id: project id
  114. :param file_filter: list of file ids
  115. :param bbox_id_filter: dict of file id and list of bbox_ids to classify
  116. :param user: username of the user asking to predict the bounding box
  117. :return:
  118. """
  119. pipeline = None
  120. # create new database instance
  121. project = Project.query.get(project_id)
  122. model_root = project.model.root_folder
  123. storage = MediaStorage(project_id, notifications)
  124. # load pipeline
  125. try:
  126. pipeline = pipelines.load_from_root_folder(project_id, model_root)
  127. # iterate over media files
  128. index = 0
  129. length = len(file_filter)
  130. for file_id in file_filter:
  131. file = project.file(file_id)
  132. file = MediaFile(file, notifications)
  133. bounding_boxes = [MediaBoundingBox(Result.get_or_404(bbox_id))
  134. for bbox_id in bbox_id_filter[file_id]]
  135. # Perform inference.
  136. bbox_labels = pipeline.pure_inference(storage, file, bounding_boxes)
  137. # Add the labels determined in the inference process.
  138. # for i, result in enumerate(result_filter[file_id]):
  139. # bbox_label = bbox_labels[i]
  140. # if isinstance(bbox_label, MediaLabel):
  141. # result.label_id = bbox_label.identifier
  142. # result.set_origin('user', commit=True)
  143. for i, bbox_id in enumerate(bbox_id_filter[file_id]):
  144. result = Result.get_or_404(bbox_id)
  145. result.set_label(bbox_labels[i].identifier, commit=True)
  146. result.set_origin('user', origin_user=user, commit=True)
  147. notifications.add(notification_manager.edit_result, result)
  148. # commit changes and yield progress
  149. db.session.commit()
  150. yield index / length, notifications
  151. index += 1
  152. finally:
  153. if pipeline is not None:
  154. pipelines.free_instance(model_root)
  155. @staticmethod
  156. def progress(progress: float, notifications: NotificationList):
  157. """
  158. fire notifications from the correct thread
  159. :param progress: [0, 1]
  160. :param notifications: Notificationlist
  161. :return: progress
  162. """
  163. notifications.fire()
  164. return progress