from typing import Union, List

from flask import make_response, request, abort
from flask.views import View

from pycs.database.Database import Database
from pycs.database.File import File
from pycs.frontend.notifications.NotificationList import NotificationList
from pycs.frontend.notifications.NotificationManager import NotificationManager
from pycs.interfaces.MediaFile import MediaFile
from pycs.interfaces.MediaStorage import MediaStorage
from pycs.jobs.JobGroupBusyException import JobGroupBusyException
from pycs.jobs.JobRunner import JobRunner
from pycs.util.PipelineCache import PipelineCache


class PredictModel(View):
    """
    load a model and create predictions
    """
    # pylint: disable=arguments-differ
    methods = ['POST']

    def __init__(self,
                 db: Database, nm: NotificationManager, jobs: JobRunner, pipelines: PipelineCache):
        # pylint: disable=invalid-name
        self.db = db
        self.nm = nm
        self.jobs = jobs
        self.pipelines = pipelines

    def dispatch_request(self, project_id):
        """
        validate the request and schedule a prediction job for the project

        The JSON body must be an object whose ``predict`` key is either
        ``'all'`` or ``'new'``.

        :param project_id: project id taken from the URL
        :return: empty response once the job has been scheduled
        """
        # extract request data
        data = request.get_json(force=True)

        # the body may be valid JSON that is not an object (null, a list, a
        # string, ...); reject it with 400 instead of failing with a TypeError
        if not isinstance(data, dict) or data.get('predict') not in ('all', 'new'):
            return abort(400)

        # find project
        project = self.db.project(project_id)
        if project is None:
            return abort(404)

        # create job
        try:
            notifications = NotificationList(self.nm)

            self.jobs.run(project,
                          'Model Interaction',
                          f'{project.name} (create predictions)',
                          f'{project.name}/model-interaction',
                          self.load_and_predict,
                          self.db, self.pipelines, notifications,
                          project.identifier, data['predict'],
                          progress=self.progress)
        except JobGroupBusyException:
            # presumably raised when a job of the same group
            # ('<project>/model-interaction') is still running — verify in JobRunner
            return abort(400)

        return make_response()

    @staticmethod
    def load_and_predict(database: Database, pipelines: PipelineCache,
                         notifications: NotificationList,
                         project_id: int, file_filter: Union[str, List[File]]):
        """
        load the pipeline and call the execute function

        Existing predictions of every selected file are removed before the
        pipeline is executed on it; the database copy is committed after
        each file.

        :param database: database object
        :param pipelines: pipeline cache
        :param notifications: notification object
        :param project_id: project id
        :param file_filter: list of files or 'new' / 'all'
        :return: generator yielding ``(progress, notifications)`` after each
                 processed file; progress reaches 1.0 after the last file
        """
        database_copy = None
        pipeline = None

        # create new database instance
        try:
            database_copy = database.copy()
            project = database_copy.project(project_id)
            model = project.model()
            storage = MediaStorage(database_copy, project_id, notifications)

            # create a lazy sequence of MediaFile objects and its length
            if isinstance(file_filter, str):
                if file_filter == 'new':
                    length = project.count_files_without_results()
                    files = (MediaFile(f, notifications)
                             for f in project.files_without_results())
                else:
                    length = project.count_files()
                    files = (MediaFile(f, notifications)
                             for f in project.files())
            else:
                # re-fetch each file through the copied database (presumably so
                # the objects belong to the connection this job commits on)
                files = (MediaFile(project.file(f.identifier), notifications)
                         for f in file_filter)
                length = len(file_filter)

            # load pipeline
            try:
                pipeline = pipelines.load_from_root_folder(project, model.root_folder)

                # counting from 1 makes the progress yielded after the last
                # file equal to 1.0 instead of (length - 1) / length
                for index, file in enumerate(files, start=1):
                    # remove old predictions
                    file.remove_predictions()

                    # create new predictions
                    pipeline.execute(storage, file)

                    # commit changes and yield progress
                    database_copy.commit()
                    yield index / length, notifications
            finally:
                # release the pipeline instance only if it was actually loaded
                if pipeline is not None:
                    pipelines.free_instance(model.root_folder)
        finally:
            if database_copy is not None:
                database_copy.close()

    @staticmethod
    def progress(progress: float, notifications: NotificationList) -> float:
        """
        fire notifications from the correct thread

        :param progress: [0, 1]
        :param notifications: Notificationlist
        :return: progress, unchanged
        """
        notifications.fire()
        return progress