from typing import List
from typing import Union

from flask import abort
from flask import make_response
from flask import request
from flask.views import View

from pycs import db
from pycs.database.Project import Project
from pycs.frontend.notifications.NotificationList import NotificationList
from pycs.frontend.notifications.NotificationManager import NotificationManager
from pycs.interfaces.MediaFile import MediaFile
from pycs.interfaces.MediaStorage import MediaStorage
from pycs.jobs.JobGroupBusyException import JobGroupBusyException
from pycs.jobs.JobRunner import JobRunner
from pycs.util.PipelineCache import PipelineCache


class PredictModel(View):
    """
    load a model and create predictions

    Accepts ``POST`` requests whose JSON body carries a ``predict`` field
    (``'all'`` or ``'new'``) and hands a prediction job to the job runner.
    """
    # pylint: disable=arguments-differ
    methods = ['POST']

    def __init__(self, nm: NotificationManager, jobs: JobRunner, pipelines: PipelineCache):
        """
        store the collaborators needed to schedule prediction jobs

        :param nm: notification manager used to build notification lists
        :param jobs: job runner the prediction job is submitted to
        :param pipelines: pipeline cache passed on to the prediction job
        """
        # pylint: disable=invalid-name
        self.nm = nm
        self.jobs = jobs
        self.pipelines = pipelines

    def dispatch_request(self, project_id):
        """
        validate the request body and schedule the prediction job

        Aborts with 400 if the ``predict`` argument is missing or invalid, or
        if the job runner reports the job group as busy.

        :param project_id: id of the project to create predictions for
        :return: empty response after the job has been submitted
        """
        # extract request data
        data = request.get_json(force=True)

        # valid JSON that is not an object (e.g. `null` or a list) has no
        # `get` method; reject it instead of raising an AttributeError
        if not isinstance(data, dict):
            abort(400, "predict argument is missing")

        predict = data.get('predict')

        if predict is None:
            abort(400, "predict argument is missing")

        if predict not in ['all', 'new']:
            abort(400, "predict must be either 'all' or 'new'")

        # find project
        project = Project.get_or_404(project_id)

        # create job; only the submission itself is guarded by the try block
        notifications = NotificationList(self.nm)

        try:
            self.jobs.run(project,
                          'Model Interaction',
                          f'{project.name} (create predictions)',
                          f'{project.id}/model-interaction',
                          PredictModel.load_and_predict,
                          self.pipelines, notifications,
                          project.id, predict,
                          progress=self.progress)

        except JobGroupBusyException:
            abort(400, "Model prediction is already running")

        return make_response()

    @staticmethod
    def load_and_predict(pipelines: PipelineCache,
                         notifications: NotificationList,
                         project_id: int, file_filter: Union[str, List[int]]):
        """
        load the pipeline and call the execute function for every selected file

        This is a generator: after each processed file it yields a tuple of
        the fraction of files finished so far and the notification list.

        :param pipelines: pipeline cache
        :param notifications: notification object
        :param project_id: project id
        :param file_filter: list of file ids or 'new' / 'all'
        :return: generator of ``(progress, notifications)`` tuples
        """
        pipeline = None

        # look the project up again by id inside the job
        project = Project.query.get(project_id)
        model_root = project.model.root_folder
        storage = MediaStorage(project_id, notifications)

        # select the files to process and determine how many there are
        if isinstance(file_filter, str):
            if file_filter == 'new':
                files = project.files_without_results()
                length = project.count_files_without_results()

            else:
                # any string other than 'new' selects all files
                files = project.files.all()
                length = project.files.count()

        else:
            files = [project.file(identifier) for identifier in file_filter]
            length = len(files)

        # wrap the files lazily, one MediaFile per iteration step
        media_files = (MediaFile(file, notifications) for file in files)

        # load pipeline
        try:
            pipeline = pipelines.load_from_root_folder(project, model_root)

            # iterate over media files; counting starts at 1 so the yielded
            # value is the share of files already finished and the last
            # file reports 1.0
            for finished, media_file in enumerate(media_files, start=1):
                # remove old predictions
                media_file.remove_predictions()

                # create new predictions
                pipeline.execute(storage, media_file)

                # commit changes and yield progress
                db.session.commit()
                yield finished / length, notifications

        finally:
            # release the pipeline only if loading it succeeded
            if pipeline is not None:
                pipelines.free_instance(model_root)

    @staticmethod
    def progress(progress: float, notifications: NotificationList):
        """
        fire the collected notifications and pass the progress value through

        Passed to the job runner as its ``progress`` callback; presumably it
        is invoked there so that notifications are fired from the correct
        thread (verify against JobRunner).

        :param progress: [0, 1]
        :param notifications: NotificationList
        :return: progress
        """
        notifications.fire()
        return progress