PipelineManager.py 1.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546
  1. from os import path
  2. from eventlet import tpool
  3. from pycs.pipeline.Fit import Fit
  4. from pycs.pipeline.Job import Job
  class PipelineManager:
      """Load a project's model pipeline and forward close/run/fit calls to it.

      The pipeline class is resolved dynamically from the project's model
      configuration (see ``__init__``). ``run`` predicts on a single media
      file and ``fit`` trains on all of the project's data entries.
      """

      def __init__(self, project):
          """Import and instantiate the pipeline configured for *project*.

          :param project: project mapping. Reads ``project['model']['path']``,
              ``project['model']['code']['module']`` and
              ``project['model']['code']['class']``. The module is located by
              joining the model path with the module path and turning path
              separators into dots, e.g. ``models/foo`` + ``pipeline`` ->
              ``models.foo.pipeline``. The class is then called as
              ``cls(project['model']['path'], project['model'])``.
          :raises ImportError: if the derived module cannot be imported.
          :raises AttributeError: if the module has no attribute of that name.
          """
          model = project['model']
          code = model['code']
          class_name = code['class']

          # normpath collapses "./", "//" and "a/../b" segments, which would
          # otherwise produce empty or leading dots in the dotted module name.
          code_path = path.normpath(path.join(model['path'], code['module']))
          module_name = code_path.replace('/', '.').replace('\\', '.')

          # a non-empty fromlist makes __import__ return the leaf module
          # rather than the top-level package
          module = __import__(module_name, fromlist=[class_name])
          pipeline_class = getattr(module, class_name)

          self.project = project
          self.pipeline = pipeline_class(model['path'], model)

      def close(self):
          """Shut down the wrapped pipeline by calling its ``close()``."""
          print('PipelineManager', 'close')
          self.pipeline.close()

      def run(self, media_file):
          """Predict on *media_file* and replace its pipeline results.

          The blocking ``pipeline.execute(job)`` call is handed to eventlet's
          native thread pool via ``tpool.execute``. Existing pipeline results
          are removed only after execution has returned, so an exception from
          the pipeline leaves the previous results in place.

          :param media_file: object providing ``remove_pipeline_results()``
              and ``add_result(prediction, origin=...)``.
          """
          # create the job for this media file
          # TODO update job progress
          job = Job(self.project['id'], media_file)
          predictions = tpool.execute(self.pipeline.execute, job)

          # remove existing pipeline predictions from media_file,
          # then add the new ones
          media_file.remove_pipeline_results()
          for prediction in predictions:
              media_file.add_result(prediction, origin='pipeline')

      def fit(self):
          """Train the pipeline on every entry in ``project['data']``.

          Each data entry is wrapped in a ``Fit`` together with the project
          id, and the resulting list is passed to ``pipeline.fit``.
          """
          print('PipelineManager', 'fit')
          data = [
              Fit(self.project['id'], entry)
              for entry in self.project['data'].values()
          ]
          self.pipeline.fit(data)