6
0

PredictModel.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125
  1. from typing import Any
  2. from flask import abort
  3. from flask import make_response
  4. from flask import request
  5. from flask.views import View
  6. from pycs import app
  7. from pycs import db
  8. from pycs.database.Database import Database
  9. from pycs.database.Project import Project
  10. from pycs.frontend.notifications.NotificationList import NotificationList
  11. from pycs.frontend.notifications.NotificationManager import NotificationManager
  12. from pycs.interfaces.MediaFile import MediaFile
  13. from pycs.interfaces.MediaStorage import MediaStorage
  14. from pycs.jobs.JobGroupBusyException import JobGroupBusyException
  15. from pycs.jobs.JobRunner import JobRunner
  16. from pycs.util.PipelineCache import PipelineCache
  17. class PredictModel(View):
  18. """
  19. load a model and create predictions
  20. """
  21. # pylint: disable=arguments-differ
  22. methods = ['POST']
  23. def __init__(self, nm: NotificationManager, jobs: JobRunner, pipelines: PipelineCache):
  24. # pylint: disable=invalid-name
  25. self.nm = nm
  26. self.jobs = jobs
  27. self.pipelines = pipelines
  28. def dispatch_request(self, project_id):
  29. # extract request data
  30. data = request.get_json(force=True)
  31. if data.get('predict') not in ['all', 'new']:
  32. return abort(400)
  33. # find project
  34. project = Project.query.get(project_id)
  35. if project is None:
  36. return abort(404)
  37. # create job
  38. try:
  39. notifications = NotificationList(self.nm)
  40. self.jobs.run(project,
  41. 'Model Interaction',
  42. f'{project.name} (create predictions)',
  43. f'{project.name}/model-interaction',
  44. PredictModel.load_and_predict,
  45. self.pipelines, notifications,
  46. project.id, data['predict'],
  47. progress=self.progress)
  48. except JobGroupBusyException:
  49. return abort(400)
  50. return make_response()
  51. @staticmethod
  52. def load_and_predict(pipelines: PipelineCache,
  53. notifications: NotificationList, project_id: int, file_filter: Any):
  54. pipeline = None
  55. # create new database instance
  56. try:
  57. project = Project.query.get(project_id)
  58. model = project.model
  59. storage = MediaStorage(project_id, notifications)
  60. # create a list of MediaFile
  61. if isinstance(file_filter, str):
  62. if file_filter == 'new':
  63. length = project.count_files_without_results()
  64. files = map(lambda f: MediaFile(f, notifications),
  65. project.files_without_results())
  66. else:
  67. length = project.count_files()
  68. files = map(lambda f: MediaFile(f, notifications),
  69. project.files.all())
  70. else:
  71. files = map(lambda f: MediaFile(project.file(f.id), notifications),
  72. file_filter)
  73. length = len(file_filter)
  74. # load pipeline
  75. try:
  76. pipeline = pipelines.load_from_root_folder(project, model.root_folder)
  77. # iterate over files
  78. index = 0
  79. for file in files:
  80. # remove old predictions
  81. file.remove_predictions()
  82. # create new predictions
  83. pipeline.execute(storage, file)
  84. # commit changes and yield progress
  85. db.session.commit()
  86. yield index / length, notifications
  87. index += 1
  88. except Exception as e:
  89. import traceback
  90. traceback.print_exc()
  91. app.logger.warning(f"Pipeline Error #2: {e}")
  92. finally:
  93. if pipeline is not None:
  94. pipelines.free_instance(model.root_folder)
  95. except Exception as e:
  96. import traceback
  97. traceback.print_exc()
  98. app.logger.warning(f"Pipeline Error #1: {e}")
  99. @staticmethod
  100. def progress(progress: float, notifications: NotificationList):
  101. notifications.fire()
  102. return progress