123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136 |
- from typing import List
- from typing import Union
- from flask import abort
- from flask import make_response
- from flask import request
- from flask.views import View
- from pycs import db
- from pycs.database.Project import Project
- from pycs.frontend.notifications.NotificationList import NotificationList
- from pycs.frontend.notifications.NotificationManager import NotificationManager
- from pycs.interfaces.MediaFile import MediaFile
- from pycs.interfaces.MediaStorage import MediaStorage
- from pycs.jobs.JobGroupBusyException import JobGroupBusyException
- from pycs.jobs.JobRunner import JobRunner
- from pycs.util.PipelineCache import PipelineCache
class PredictModel(View):
    """
    load a model and create predictions

    POST endpoint that starts a background job which runs the project's model
    pipeline over the selected files of a project.
    """
    # pylint: disable=arguments-differ
    methods = ['POST']

    def __init__(self, nm: NotificationManager, jobs: JobRunner, pipelines: PipelineCache):
        # pylint: disable=invalid-name
        self.nm = nm
        self.jobs = jobs
        self.pipelines = pipelines

    def dispatch_request(self, project_id):
        """
        validate the request and schedule a prediction job for the project

        Expects a JSON object body with a ``predict`` key whose value is either
        ``'all'`` (predict every file) or ``'new'`` (only files without results).

        :param project_id: project id from the URL
        :return: empty response once the job has been scheduled
        :raises HTTPException: 404 if the project does not exist (via
                               ``Project.get_or_404``), 400 if the body is not a
                               JSON object, ``predict`` is missing/invalid, or a
                               model job for this project is already running
        """
        project = Project.get_or_404(project_id)

        # extract request data
        data = request.get_json(force=True)
        # A valid JSON body can still be e.g. `null` or a list; without this
        # check `data.get` would raise AttributeError and surface as a 500.
        if not isinstance(data, dict):
            abort(400, "request body must be a JSON object")

        predict = data.get('predict')
        if predict is None:
            abort(400, "predict argument is missing")
        if predict not in ['all', 'new']:
            abort(400, "predict must be either 'all' or 'new'")

        # create job; the job group string makes concurrent model jobs for the
        # same project raise JobGroupBusyException
        try:
            notifications = NotificationList(self.nm)
            self.jobs.run(project,
                          'Model Interaction',
                          f'{project.name} (create predictions)',
                          f'{project.id}/model-interaction',
                          PredictModel.load_and_predict,
                          self.pipelines, notifications,
                          project.id, predict,
                          progress=self.progress)
        except JobGroupBusyException:
            abort(400, "Model prediction is already running")

        return make_response()

    @staticmethod
    def load_and_predict(pipelines: PipelineCache,
                         notifications: NotificationList,
                         project_id: int, file_filter: Union[str, List[int]]):
        """
        load the pipeline and call its execute function for every selected file

        This is a generator. For each file it removes the old predictions, runs
        the pipeline, commits the database session and then yields a
        ``(progress, notifications)`` tuple.

        :param pipelines: pipeline cache
        :param notifications: notification object
        :param project_id: project id
        :param file_filter: list of file ids or 'new' / 'all'
        :return: generator of ``(progress, notifications)`` tuples
        """
        pipeline = None

        # look the project up again by id; this function is executed by the
        # job runner rather than inside the request handler
        project = Project.query.get(project_id)
        model_root = project.model.root_folder
        storage = MediaStorage(project_id, notifications)

        # select the files to process and how many there are
        if isinstance(file_filter, str):
            if file_filter == 'new':
                files = project.files_without_results()
                length = project.count_files_without_results()
            else:
                files = project.files.all()
                length = project.files.count()
        else:
            # NOTE(review): presumably project.file returns None for unknown
            # ids, which would fail inside MediaFile — confirm
            files = [project.file(identifier) for identifier in file_filter]
            length = len(files)

        # wrap lazily, so each MediaFile is only created when it is processed
        media_files = (MediaFile(file, notifications) for file in files)

        # load pipeline and always hand it back to the cache afterwards
        try:
            pipeline = pipelines.load_from_root_folder(project_id, model_root)

            for index, file in enumerate(media_files):
                # remove old predictions
                file.remove_predictions()

                # create new predictions
                pipeline.execute(storage, file)

                # commit changes and yield progress; the value is the 0-based
                # index over the total, i.e. the fraction of files finished
                # before this one
                db.session.commit()
                yield index / length, notifications
        finally:
            if pipeline is not None:
                pipelines.free_instance(model_root)

    @staticmethod
    def progress(progress: float, notifications: NotificationList):
        """
        fire notifications from the correct thread

        :param progress: [0, 1]
        :param notifications: Notificationlist
        :return: progress
        """
        notifications.fire()
        return progress
|