PredictModel.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
  1. from typing import Any
  2. from flask import make_response, request, abort
  3. from flask.views import View
  4. from pycs.database.Database import Database
  5. from pycs.frontend.notifications.NotificationList import NotificationList
  6. from pycs.frontend.notifications.NotificationManager import NotificationManager
  7. from pycs.interfaces.MediaFile import MediaFile
  8. from pycs.interfaces.MediaStorage import MediaStorage
  9. from pycs.jobs.JobGroupBusyException import JobGroupBusyException
  10. from pycs.jobs.JobRunner import JobRunner
  11. from pycs.util.PipelineUtil import load_from_root_folder as load_pipeline
  12. class PredictModel(View):
  13. """
  14. load a model and create predictions
  15. """
  16. # pylint: disable=arguments-differ
  17. methods = ['POST']
  18. def __init__(self, db: Database, nm: NotificationManager, jobs: JobRunner):
  19. # pylint: disable=invalid-name
  20. self.db = db
  21. self.nm = nm
  22. self.jobs = jobs
  23. def dispatch_request(self, project_id):
  24. # extract request data
  25. data = request.get_json(force=True)
  26. if 'predict' not in data or data['predict'] not in ['all', 'new']:
  27. return abort(400)
  28. # find project
  29. project = self.db.project(project_id)
  30. if project is None:
  31. return abort(404)
  32. # create job
  33. try:
  34. notifications = NotificationList(self.nm)
  35. self.jobs.run(project,
  36. 'Model Interaction',
  37. f'{project.name} (create predictions)',
  38. f'{project.name}/model-interaction',
  39. self.load_and_predict,
  40. self.db, project.identifier, data['predict'], notifications,
  41. progress=self.progress)
  42. except JobGroupBusyException:
  43. return abort(400)
  44. return make_response()
  45. @staticmethod
  46. def load_and_predict(database: Database, project_id: int, file_filter: Any,
  47. notifications: NotificationList):
  48. db = None
  49. pipeline = None
  50. # create new database instance
  51. try:
  52. db = database.copy()
  53. project = db.project(project_id)
  54. model = project.model()
  55. storage = MediaStorage(db, project_id, notifications)
  56. # create a list of MediaFile
  57. if isinstance(file_filter, str):
  58. if file_filter == 'new':
  59. length = project.count_files_without_results()
  60. files = map(lambda f: MediaFile(f, notifications),
  61. project.files_without_results())
  62. else:
  63. length = project.count_files()
  64. files = map(lambda f: MediaFile(f, notifications),
  65. project.files())
  66. else:
  67. files = map(lambda f: MediaFile(project.file(f.identifier), notifications),
  68. file_filter)
  69. length = len(file_filter)
  70. # load pipeline
  71. try:
  72. pipeline = load_pipeline(model.root_folder)
  73. # iterate over files
  74. index = 0
  75. for file in files:
  76. # remove old predictions
  77. file.remove_predictions()
  78. # create new predictions
  79. pipeline.execute(storage, file)
  80. # commit changes and yield progress
  81. db.commit()
  82. yield index / length, notifications
  83. index += 1
  84. finally:
  85. if pipeline is not None:
  86. pipeline.close()
  87. finally:
  88. if db is not None:
  89. db.close()
  90. @staticmethod
  91. def progress(progress: float, notifications: NotificationList):
  92. notifications.fire()
  93. return progress