123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111 |
- from contextlib import closing
- from flask import make_response, request, abort
- from flask.views import View
- from pycs.database.Database import Database
- from pycs.frontend.notifications.NotificationManager import NotificationManager
- from pycs.interfaces.MediaFile import MediaFile
- from pycs.jobs.JobGroupBusyException import JobGroupBusyException
- from pycs.jobs.JobRunner import JobRunner
- from pycs.util.PipelineUtil import load_from_root_folder as load_pipeline
class PredictModel(View):
    """
    load a model and create predictions
    """
    # pylint: disable=arguments-differ
    methods = ['POST']

    def __init__(self, db: Database, nm: NotificationManager, jobs: JobRunner):
        # pylint: disable=invalid-name
        self.db = db
        self.nm = nm
        self.jobs = jobs

    def dispatch_request(self, project_id):
        """
        Submit a background job that runs the project's model on its files.

        The request body must be a JSON object ``{"predict": "all" | "new"}``.
        ``"new"`` predicts only the files returned by
        ``project.files_without_results()``; ``"all"`` predicts every file
        returned by ``project.files()``.

        :param project_id: identifier of the project to predict on
        :return: empty response once the job has been submitted
        Aborts with 400 if the body is malformed or the job group is busy
        (``JobGroupBusyException``), and with 404 if the project is unknown.
        """
        # extract request data; anything that is not a JSON object
        # (e.g. null, a list or a string) is rejected with 400 instead of
        # crashing with a 500 on the key lookup below
        data = request.get_json(force=True)
        if not isinstance(data, dict) or data.get('predict') not in ('all', 'new'):
            return abort(400)

        # find project
        project = self.db.project(project_id)
        if project is None:
            return abort(404)

        # get model
        model = project.model()

        # collect files; materialize them so the list can be indexed from the
        # progress callback and is not exhausted by building ``objects``
        if data['predict'] == 'new':
            files = list(project.files_without_results())
        else:
            files = list(project.files())

        objects = [MediaFile(file) for file in files]

        def store(index, length, result):
            # progress callback: receives each (index, length, result) tuple
            # yielded by load_and_predict, persists the result and returns
            # the fraction of files processed so far
            self._store_result(files[index], result)
            return (index + 1) / length

        # create job
        try:
            self.jobs.run(project,
                          'Model Interaction',
                          f'{project.name} (create predictions)',
                          f'{project.name}/model-interaction',
                          self.load_and_predict, model, objects,
                          progress=store)
        except JobGroupBusyException:
            return abort(400)

        return make_response()

    def _store_result(self, file, result):
        """
        Replace the pipeline-created results of ``file`` with ``result``.

        All changes happen inside one ``with self.db`` block, and a
        notification is sent for every removed or created result and for
        collection changes.

        :param file: file the predictions belong to
        :param result: iterable of entry dicts produced by the pipeline. Each
            entry has a ``'type'`` key and may have a ``'label'`` key; both
            are removed from the dict in place, and the remaining keys are
            stored as the result data.
        """
        # start transaction
        with self.db:
            # remove previous results with origin 'pipeline'; results of any
            # other origin are kept. Iterate over a snapshot so deleting rows
            # cannot interfere with the iteration itself.
            for remove in list(file.results()):
                if remove.origin == 'pipeline':
                    remove.remove()
                    self.nm.remove_result(remove)

            # iterate over result entries
            for entry in result:
                entry_type = entry.pop('type')

                # a 'collection' entry updates the file's collection
                # instead of creating a result
                if entry_type == 'collection':
                    file.set_collection_by_reference(entry['reference'])
                    self.nm.edit_file(file)
                    continue

                label = entry.pop('label', None)

                created = file.create_result('pipeline', entry_type, label, entry)
                self.nm.create_result(created)

    @staticmethod
    def load_and_predict(model, files):
        """
        Load the pipeline from ``model.root_folder`` and run it on each file.

        :param model: model whose ``root_folder`` holds the pipeline
        :param files: sequence of ``MediaFile`` objects to predict
        :return: generator yielding ``(index, length, result)`` per file
        """
        # closing() guarantees the pipeline is closed even if execute raises
        # or the generator is closed before it is exhausted
        with closing(load_pipeline(model.root_folder)) as pipeline:
            length = len(files)
            for index, file in enumerate(files):
                yield index, length, pipeline.execute(file)
|