123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051 |
- from os import path
- from eventlet import tpool
- from pycs.pipeline.Fit import Fit
- from pycs.pipeline.Job import Job
class PipelineManager:
    """Load a project's model pipeline class at runtime and delegate to it.

    The pipeline class is located via ``project['model']['code']`` and is
    instantiated as ``cls(project['model']['path'], project['model'])``.
    The resulting object is expected to provide ``execute(job)``,
    ``fit(data)`` and ``close()``.
    """

    def __init__(self, project):
        """Import and instantiate the pipeline configured for ``project``.

        :param project: project object; ``project['model']`` must contain
            ``'path'`` and ``'code'`` with ``'module'`` and ``'class'`` keys.
            :meth:`fit` additionally reads ``project['data']`` and
            ``project.unmanaged_files``.
        :raises ImportError: (e.g. ``ModuleNotFoundError``) if the configured
            module cannot be imported.
        :raises AttributeError: if the module has no attribute named by
            ``project['model']['code']['class']``.
        """
        model = project['model']
        module_name = self._module_name(model['path'], model['code']['module'])
        class_name = model['code']['class']

        # A non-empty fromlist makes __import__ return the leaf module
        # (e.g. ``a.b.c``) instead of the top-level package ``a``.
        module = __import__(module_name, fromlist=[class_name])
        pipeline_class = getattr(module, class_name)

        self.project = project
        self.pipeline = pipeline_class(model['path'], model)

    @staticmethod
    def _module_name(model_path, module):
        """Convert ``model_path`` joined with ``module`` into a dotted import name.

        The joined path is normalised first, so a leading ``./``, doubled or
        trailing separators, or ``..`` components do not leave empty
        segments (such as ``'..models..x'``) in the dotted name.
        """
        code_path = path.normpath(path.join(model_path, module))
        # normpath may emit either separator depending on the OS; map both.
        return code_path.replace('/', '.').replace('\\', '.')

    def close(self):
        """Close the wrapped pipeline."""
        print('PipelineManager', 'close')
        self.pipeline.close()

    def run(self, media_file):
        """Execute the pipeline on ``media_file`` and store its predictions.

        Existing pipeline results on ``media_file`` are removed only after the
        pipeline call has returned, so an exception raised by the pipeline
        leaves them untouched. Each returned prediction is then added with
        ``origin='pipeline'``.

        :param media_file: object providing ``remove_pipeline_results()`` and
            ``add_result(prediction, origin=...)``.
        """
        # TODO update job progress
        job = Job(self.project, media_file)

        # Dispatched through eventlet's tpool instead of being called directly
        # on the current green thread (presumably because execute() blocks).
        result = tpool.execute(self.pipeline.execute, job)

        # Replace previous pipeline predictions with the new ones.
        media_file.remove_pipeline_results()
        for prediction in result:
            media_file.add_result(prediction, origin='pipeline')

    def fit(self):
        """Call the pipeline's ``fit`` with every known data item.

        The list passed on holds one ``Fit`` per entry of ``project['data']``,
        followed by one per unmanaged file (built from its ``get_data()``),
        each in iteration order.
        """
        print('PipelineManager', 'fit')
        data = [
            Fit(self.project, self.project['data'][identifier])
            for identifier in self.project['data']
        ]
        unmanaged = self.project.unmanaged_files
        data.extend(Fit(self.project, unmanaged[key].get_data()) for key in unmanaged)
        self.pipeline.fit(data)
|