1
1

FitModel.py 2.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. from flask import make_response, request, abort
  2. from flask.views import View
  3. from pycs.database.Database import Database
  4. from pycs.interfaces.MediaStorage import MediaStorage
  5. from pycs.jobs.JobGroupBusyException import JobGroupBusyException
  6. from pycs.jobs.JobRunner import JobRunner
  7. from pycs.util.PipelineUtil import load_from_root_folder as load_pipeline
  class FitModel(View):
      """
      Use annotated data to fit a project's model.

      POST-only view. The request body must be a JSON object containing
      ``"fit": true``; the actual fitting is handed to the job runner and
      performed by :meth:`load_and_fit`.
      """
      # pylint: disable=arguments-differ
      methods = ['POST']

      def __init__(self, db: Database, jobs: JobRunner):
          # pylint: disable=invalid-name
          self.db = db
          self.jobs = jobs

      def dispatch_request(self, project_id):
          """
          Validate the request and schedule a fit job for a project.

          :param project_id: identifier of the project whose model to fit
          :return: an empty response once the job has been submitted
          :raises HTTPException: via ``abort`` -- 400 if the body is not a
              JSON object with ``fit`` set to exactly ``true`` or if the job
              group is busy, 404 if the project does not exist
          """
          # extract request data; force=True parses the body as JSON even
          # without a JSON Content-Type header
          data = request.get_json(force=True)

          # Valid JSON is not necessarily an object (`null`, `5`, `"fit"`,
          # `["fit"]`). Without the isinstance check such bodies raise
          # TypeError in the membership/subscript test and become a 500
          # instead of a 400. For dicts, `.get('fit') is not True` is the
          # same test as "'fit' missing or not exactly True".
          if not isinstance(data, dict) or data.get('fit') is not True:
              return abort(400)

          # find project
          project = self.db.project(project_id)
          if project is None:
              return abort(404)

          # create job: hand load_and_fit and its arguments (the shared
          # database handle and the project identifier) to the job runner
          try:
              self.jobs.run(project,
                            'Model Interaction',
                            f'{project.name} (fit model with new data)',
                            f'{project.name}/model-interaction',
                            self.load_and_fit, self.db, project.identifier)
          except JobGroupBusyException:
              # NOTE(review): presumably another job in this project's
              # model-interaction group is still running -- confirm in JobRunner
              return abort(400)

          return make_response()

      @staticmethod
      def load_and_fit(database: Database, project_id: int):
          """
          Job body: fit the project's model pipeline on the project's media.

          This is a generator that re-yields everything the pipeline's
          ``fit`` yields. If ``fit`` returns ``None`` (a plain function rather
          than a generator), it runs to completion and nothing is yielded.
          The loaded pipeline and the copied database are closed afterwards,
          including when an exception propagates.

          :param database: database handle; a copy of it is used in this job
          :param project_id: identifier of the project whose model is fitted
          """
          # create new database instance rather than reusing the caller's
          # handle (presumably because the job runs outside the request
          # thread -- TODO confirm against JobRunner)
          db = database.copy()
          try:
              project = db.project(project_id)
              model = project.model()
              storage = MediaStorage(db, project_id)

              # load pipeline
              pipeline = load_pipeline(model.root_folder)
              try:
                  progress = pipeline.fit(storage)
                  # Test for None explicitly instead of catching the
                  # TypeError that `yield from None` raises: catching
                  # TypeError there would also hide genuine TypeErrors raised
                  # while loading or fitting the model, making a failed fit
                  # look like a successful one.
                  if progress is not None:
                      yield from progress
              finally:
                  pipeline.close()
          finally:
              db.close()