6
0

FitModel.py 2.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
  1. from flask import abort
  2. from flask import make_response
  3. from flask import request
  4. from flask.views import View
  5. from pycs.database.Project import Project
  6. from pycs.interfaces.MediaStorage import MediaStorage
  7. from pycs.jobs.JobGroupBusyException import JobGroupBusyException
  8. from pycs.jobs.JobRunner import JobRunner
  9. from pycs.util.PipelineCache import PipelineCache
  10. class FitModel(View):
  11. """
  12. use annotated data to fit a model
  13. """
  14. # pylint: disable=arguments-differ
  15. methods = ['POST']
  16. def __init__(self, jobs: JobRunner, pipelines: PipelineCache):
  17. # pylint: disable=invalid-name
  18. self.jobs = jobs
  19. self.pipelines = pipelines
  20. def dispatch_request(self, user: str, project_id):
  21. # pylint: disable=unused-argument
  22. project = Project.get_or_404(project_id)
  23. # extract request data
  24. data = request.get_json(force=True)
  25. if not data.get('fit', False):
  26. abort(400, "fit flag is missing")
  27. # create job
  28. try:
  29. self.jobs.run(project,
  30. 'Model Interaction',
  31. f'{project.name} (fit model with new data)',
  32. f'{project.name}/model-interaction',
  33. FitModel.load_and_fit,
  34. self.pipelines,
  35. project.id)
  36. except JobGroupBusyException:
  37. return abort(400, "Model fitting already running")
  38. return make_response()
  39. @staticmethod
  40. def load_and_fit(pipelines: PipelineCache, project_id: int):
  41. """
  42. load the pipeline and call the fit function
  43. :param pipelines: pipeline cache
  44. :param project_id: project id
  45. """
  46. pipeline = None
  47. # create new database instance
  48. project = Project.query.get(project_id)
  49. model = project.model
  50. storage = MediaStorage(project_id)
  51. # load pipeline
  52. try:
  53. pipeline = pipelines.load_from_root_folder(project_id, model.root_folder)
  54. yield from pipeline.fit(storage)
  55. except TypeError:
  56. pass
  57. finally:
  58. if pipeline is not None:
  59. pipelines.free_instance(model.root_folder)