# FitModel.py (2.4 KB)
  1. from flask import make_response, request, abort
  2. from flask.views import View
  3. from pycs.database.Database import Database
  4. from pycs.interfaces.MediaStorage import MediaStorage
  5. from pycs.jobs.JobGroupBusyException import JobGroupBusyException
  6. from pycs.jobs.JobRunner import JobRunner
  7. from pycs.util.PipelineCache import PipelineCache
  8. class FitModel(View):
  9. """
  10. use annotated data to fit a model
  11. """
  12. # pylint: disable=arguments-differ
  13. methods = ['POST']
  14. def __init__(self, db: Database, jobs: JobRunner, pipelines: PipelineCache):
  15. # pylint: disable=invalid-name
  16. self.db = db
  17. self.jobs = jobs
  18. self.pipelines = pipelines
  19. def dispatch_request(self, project_id):
  20. # extract request data
  21. data = request.get_json(force=True)
  22. if 'fit' not in data or data['fit'] is not True:
  23. return abort(400)
  24. # find project
  25. project = self.db.project(project_id)
  26. if project is None:
  27. return abort(404)
  28. # create job
  29. try:
  30. self.jobs.run(project,
  31. 'Model Interaction',
  32. f'{project.name} (fit model with new data)',
  33. f'{project.name}/model-interaction',
  34. self.load_and_fit, self.db, project.identifier)
  35. except JobGroupBusyException:
  36. return abort(400)
  37. return make_response()
  38. @staticmethod
  39. def load_and_fit(database: Database, pipelines: PipelineCache, project_id: int):
  40. """
  41. load the pipeline and call the fit function
  42. :param database: database object
  43. :param pipelines: pipeline cache
  44. :param project_id: project id
  45. """
  46. database_copy = None
  47. pipeline = None
  48. # create new database instance
  49. try:
  50. database_copy = database.copy()
  51. project = database_copy.project(project_id)
  52. model = project.model()
  53. storage = MediaStorage(database_copy, project_id)
  54. # load pipeline
  55. try:
  56. pipeline = pipelines.load_from_root_folder(project, model.root_folder)
  57. yield from pipeline.fit(storage)
  58. except TypeError:
  59. pass
  60. finally:
  61. if pipeline is not None:
  62. pipelines.free_instance(model.root_folder)
  63. finally:
  64. if database_copy is not None:
  65. database_copy.close()