6
0

PredictModel.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
  1. from typing import Any
  2. from flask import make_response, request, abort
  3. from flask.views import View
  4. from pycs.database.Database import Database
  5. from pycs.frontend.notifications.NotificationList import NotificationList
  6. from pycs.frontend.notifications.NotificationManager import NotificationManager
  7. from pycs.interfaces.MediaFile import MediaFile
  8. from pycs.interfaces.MediaStorage import MediaStorage
  9. from pycs.jobs.JobGroupBusyException import JobGroupBusyException
  10. from pycs.jobs.JobRunner import JobRunner
  11. from pycs.util.PipelineCache import PipelineCache
  12. class PredictModel(View):
  13. """
  14. load a model and create predictions
  15. """
  16. # pylint: disable=arguments-differ
  17. methods = ['POST']
  18. def __init__(self,
  19. db: Database, nm: NotificationManager, jobs: JobRunner, pipelines: PipelineCache):
  20. # pylint: disable=invalid-name
  21. self.db = db
  22. self.nm = nm
  23. self.jobs = jobs
  24. self.pipelines = pipelines
  25. def dispatch_request(self, project_id):
  26. # extract request data
  27. data = request.get_json(force=True)
  28. if 'predict' not in data or data['predict'] not in ['all', 'new']:
  29. return abort(400)
  30. # find project
  31. project = self.db.project(project_id)
  32. if project is None:
  33. return abort(404)
  34. # create job
  35. try:
  36. notifications = NotificationList(self.nm)
  37. self.jobs.run(project,
  38. 'Model Interaction',
  39. f'{project.name} (create predictions)',
  40. f'{project.name}/model-interaction',
  41. self.load_and_predict,
  42. self.db, self.pipelines, notifications,
  43. project.identifier, data['predict'],
  44. progress=self.progress)
  45. except JobGroupBusyException:
  46. return abort(400)
  47. return make_response()
  48. @staticmethod
  49. def load_and_predict(database: Database, pipelines: PipelineCache,
  50. notifications: NotificationList, project_id: int, file_filter: Any):
  51. db = None
  52. pipeline = None
  53. # create new database instance
  54. try:
  55. db = database.copy()
  56. project = db.project(project_id)
  57. model = project.model()
  58. storage = MediaStorage(db, project_id, notifications)
  59. # create a list of MediaFile
  60. if isinstance(file_filter, str):
  61. if file_filter == 'new':
  62. length = project.count_files_without_results()
  63. files = map(lambda f: MediaFile(f, notifications),
  64. project.files_without_results())
  65. else:
  66. length = project.count_files()
  67. files = map(lambda f: MediaFile(f, notifications),
  68. project.files())
  69. else:
  70. files = map(lambda f: MediaFile(project.file(f.identifier), notifications),
  71. file_filter)
  72. length = len(file_filter)
  73. # load pipeline
  74. try:
  75. pipeline = pipelines.load_from_root_folder(project, model.root_folder)
  76. # iterate over files
  77. index = 0
  78. for file in files:
  79. # remove old predictions
  80. file.remove_predictions()
  81. # create new predictions
  82. pipeline.execute(storage, file)
  83. # commit changes and yield progress
  84. db.commit()
  85. yield index / length, notifications
  86. index += 1
  87. finally:
  88. if pipeline is not None:
  89. pipelines.free_instance(model.root_folder)
  90. finally:
  91. if db is not None:
  92. db.close()
  93. @staticmethod
  94. def progress(progress: float, notifications: NotificationList):
  95. notifications.fire()
  96. return progress