FitModel.py 2.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465
  1. from flask import make_response, request, abort
  2. from flask.views import View
  3. from pycs.database.Database import Database
  4. from pycs.database.Project import Project
  5. from pycs.interfaces.MediaStorage import MediaStorage
  6. from pycs.jobs.JobGroupBusyException import JobGroupBusyException
  7. from pycs.jobs.JobRunner import JobRunner
  8. from pycs.util.PipelineCache import PipelineCache
  class FitModel(View):
      """
      use annotated data to fit a model

      POST-only Flask view. A request body of ``{"fit": true}`` schedules a
      background job (via the injected JobRunner) that loads the project's
      pipeline and fits it on the project's media storage.
      """

      # pylint: disable=arguments-differ
      methods = ['POST']

      def __init__(self, db: Database, jobs: JobRunner, pipelines: PipelineCache):
          # pylint: disable=invalid-name
          self.db = db
          self.jobs = jobs
          self.pipelines = pipelines

      def dispatch_request(self, project_id):
          """
          schedule a model fitting job for the given project

          :param project_id: id of the project whose model should be fitted
          :return: an empty response once the job has been submitted
          :raises HTTPException: 400 if the body is not ``{"fit": true}`` or a
              job of the same group is already busy, 404 if the project
              does not exist
          """
          # extract request data; a JSON body that is not an object (e.g. null,
          # a list or a string) is rejected instead of failing with a 500
          data = request.get_json(force=True)
          if not isinstance(data, dict) or data.get('fit') is not True:
              abort(400)

          # find project
          project = Project.query.get(project_id)
          if project is None:
              abort(404)

          # create job; the arguments following the callable must match
          # load_and_fit(database, pipelines, project_id) — presumably
          # JobRunner.run forwards them positionally
          try:
              self.jobs.run(project,
                            'Model Interaction',
                            f'{project.name} (fit model with new data)',
                            f'{project.name}/model-interaction',
                            self.load_and_fit,
                            self.db, self.pipelines, project.id)
          except JobGroupBusyException:
              abort(400)

          return make_response()

      @staticmethod
      def load_and_fit(database: Database, pipelines: PipelineCache, project_id: int):
          """
          load the project's pipeline and fit it on the project's media storage

          This is a generator: it re-yields everything ``pipeline.fit`` yields.
          The pipeline instance is released through the cache afterwards, even
          if fitting fails or the generator is closed early.

          :param database: database handle (currently unused; kept for the
              job's call signature)
          :param pipelines: cache used to load and free the pipeline instance
          :param project_id: id of the project to fit
          """
          # pylint: disable=unused-argument
          project = Project.query.get(project_id)
          model = project.model()
          storage = MediaStorage(project_id)

          pipeline = None
          try:
              pipeline = pipelines.load_from_root_folder(project, model.root_folder)
              yield from pipeline.fit(storage)
          except TypeError:
              # NOTE(review): presumably tolerates pipelines whose fit() returns
              # None instead of an iterable (``yield from None`` raises
              # TypeError). This also hides TypeErrors raised while loading or
              # inside fit() — confirm this is intended.
              pass
          finally:
              # only free the instance if loading actually succeeded
              if pipeline is not None:
                  pipelines.free_instance(model.root_folder)