PredictModel.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129
  1. from typing import Any
  2. from flask import make_response, request, abort
  3. from flask.views import View
  4. from pycs import app
  5. from pycs.database.Database import Database
  6. from pycs.database.Project import Project
  7. from pycs.frontend.notifications.NotificationList import NotificationList
  8. from pycs.frontend.notifications.NotificationManager import NotificationManager
  9. from pycs.interfaces.MediaFile import MediaFile
  10. from pycs.interfaces.MediaStorage import MediaStorage
  11. from pycs.jobs.JobGroupBusyException import JobGroupBusyException
  12. from pycs.jobs.JobRunner import JobRunner
  13. from pycs.util.PipelineCache import PipelineCache
  14. class PredictModel(View):
  15. """
  16. load a model and create predictions
  17. """
  18. # pylint: disable=arguments-differ
  19. methods = ['POST']
  20. def __init__(self,
  21. db: Database, nm: NotificationManager, jobs: JobRunner, pipelines: PipelineCache):
  22. # pylint: disable=invalid-name
  23. self.db = db
  24. self.nm = nm
  25. self.jobs = jobs
  26. self.pipelines = pipelines
  27. def dispatch_request(self, project_id):
  28. # extract request data
  29. data = request.get_json(force=True)
  30. if data.get('predict') not in ['all', 'new']:
  31. return abort(400)
  32. # find project
  33. project = Project.query.get(project_id)
  34. if project is None:
  35. return abort(404)
  36. # create job
  37. try:
  38. notifications = NotificationList(self.nm)
  39. self.jobs.run(project,
  40. 'Model Interaction',
  41. f'{project.name} (create predictions)',
  42. f'{project.name}/model-interaction',
  43. self.load_and_predict,
  44. self.db, self.pipelines, notifications,
  45. project.identifier, data['predict'],
  46. progress=self.progress)
  47. except JobGroupBusyException:
  48. return abort(400)
  49. return make_response()
  50. @staticmethod
  51. def load_and_predict(database: Database, pipelines: PipelineCache,
  52. notifications: NotificationList, project_id: int, file_filter: Any):
  53. db = None
  54. pipeline = None
  55. # create new database instance
  56. try:
  57. db = database.copy()
  58. project = Project.query.get(project_id)
  59. model = project.model
  60. storage = MediaStorage(project_id, notifications)
  61. # create a list of MediaFile
  62. if isinstance(file_filter, str):
  63. if file_filter == 'new':
  64. length = project.count_files_without_results()
  65. files = map(lambda f: MediaFile(f, notifications),
  66. project.files_without_results())
  67. else:
  68. length = project.count_files()
  69. files = map(lambda f: MediaFile(f, notifications),
  70. project.files.all())
  71. else:
  72. files = map(lambda f: MediaFile(project.file(f.identifier), notifications),
  73. file_filter)
  74. length = len(file_filter)
  75. # load pipeline
  76. try:
  77. pipeline = pipelines.load_from_root_folder(project, model.root_folder)
  78. # iterate over files
  79. index = 0
  80. for file in files:
  81. # remove old predictions
  82. file.remove_predictions()
  83. # create new predictions
  84. pipeline.execute(storage, file)
  85. # commit changes and yield progress
  86. db.commit()
  87. yield index / length, notifications
  88. index += 1
  89. except Exception as e:
  90. import traceback
  91. traceback.print_exc()
  92. app.logger.warning(f"Pipeline Error #2: {e}")
  93. finally:
  94. if pipeline is not None:
  95. pipelines.free_instance(model.root_folder)
  96. except Exception as e:
  97. import traceback
  98. traceback.print_exc()
  99. app.logger.warning(f"Pipeline Error #1: {e}")
  100. finally:
  101. if db is not None:
  102. db.close()
  103. @staticmethod
  104. def progress(progress: float, notifications: NotificationList):
  105. notifications.fire()
  106. return progress