FitModel.py 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
  1. from flask import make_response, request, abort
  2. from flask.views import View
  3. from pycs.database.Project import Project
  4. from pycs.interfaces.MediaStorage import MediaStorage
  5. from pycs.jobs.JobGroupBusyException import JobGroupBusyException
  6. from pycs.jobs.JobRunner import JobRunner
  7. from pycs.util.PipelineCache import PipelineCache
  8. class FitModel(View):
  9. """
  10. use annotated data to fit a model
  11. """
  12. # pylint: disable=arguments-differ
  13. methods = ['POST']
  14. def __init__(self, jobs: JobRunner, pipelines: PipelineCache):
  15. # pylint: disable=invalid-name
  16. self.jobs = jobs
  17. self.pipelines = pipelines
  18. def dispatch_request(self, project_id):
  19. # extract request data
  20. data = request.get_json(force=True)
  21. if not data.get('fit', False):
  22. abort(400, "fit flag is missing")
  23. # find project
  24. project = Project.get_or_404(project_id)
  25. # create job
  26. try:
  27. self.jobs.run(project,
  28. 'Model Interaction',
  29. f'{project.name} (fit model with new data)',
  30. f'{project.name}/model-interaction',
  31. FitModel.load_and_fit, project.id)
  32. except JobGroupBusyException:
  33. return abort(400, "Model fitting already running")
  34. return make_response()
  35. @staticmethod
  36. def load_and_fit(pipelines: PipelineCache, project_id: int):
  37. """
  38. load the pipeline and call the fit function
  39. :param pipelines: pipeline cache
  40. :param project_id: project id
  41. """
  42. database_copy = None
  43. pipeline = None
  44. # create new database instance
  45. try:
  46. database_copy = database.copy()
  47. project = Project.query.get(project_id)
  48. model = project.model
  49. storage = MediaStorage(database_copy, project_id)
  50. # load pipeline
  51. try:
  52. pipeline = pipelines.load_from_root_folder(project, model.root_folder)
  53. yield from pipeline.fit(storage)
  54. except TypeError:
  55. pass
  56. finally:
  57. if pipeline is not None:
  58. pipelines.free_instance(model.root_folder)
  59. finally:
  60. if database_copy is not None:
  61. database_copy.close()