from json import load
from os import path, mkdir, listdir
from os.path import splitext
from uuid import uuid1

from eventlet import spawn_after

from pycs.observable import ObservableDict
from pycs.pipeline.PipelineManager import PipelineManager
from pycs.projects.ImageFile import ImageFile
from pycs.projects.UnmanagedImageFile import UnmanagedImageFile
from pycs.projects.UnmanagedVideoFile import UnmanagedVideoFile
from pycs.projects.VideoFile import VideoFile
from pycs.util.RecursiveDictionary import set_recursive


class Project(ObservableDict):
    """Observable project record backed by the ``projects/<id>`` folder.

    The wrapped dict always contains ``data`` (media files keyed by id),
    ``labels`` and ``jobs``, plus ``model``, loaded from
    ``projects/<id>/model/distribution.json``. When ``unmanaged`` is not
    ``None`` it is a directory whose files are served as media instead of
    the entries in ``data``. Every change to the observed dict triggers
    ``parent.write_project(id)``.

    Prediction/fitting share one lazily created ``PipelineManager`` that is
    closed automatically after ``DEFAULT_PIPELINE_TIMEOUT`` seconds without
    a new ``predict``/``fit`` call.
    """

    # seconds an idle pipeline is kept alive after the last predict/fit call
    DEFAULT_PIPELINE_TIMEOUT = 120

    # supported media file extensions (matched case-insensitively)
    IMAGE_EXTENSIONS = ('.jpg', '.png')
    VIDEO_EXTENSIONS = ('.mp4',)

    def __init__(self, obj: dict, parent):
        """Wrap ``obj`` (mutated in place) and prepare the project folders.

        :param obj: project dict; must contain ``id`` and ``unmanaged``
        :param parent: owner object, expected to provide ``write_project``
        """
        self.pipeline_manager = None
        self.quit_pipeline_thread = None

        # unmanaged projects only: ids in chain order and id -> media file
        self.unmanaged_files_keys = []
        self.unmanaged_files = {}

        # ensure all required object keys are available
        for key in ('data', 'labels', 'jobs'):
            if key not in obj.keys():
                obj[key] = {}

        # load model description (JSON is UTF-8 by specification)
        folder = path.join('projects', obj['id'], 'model')
        with open(path.join(folder, 'distribution.json'), 'r', encoding='utf-8') as file:
            model = load(file)
        model['path'] = folder
        obj['model'] = model

        if obj['unmanaged'] is None:
            # replace the raw dicts in ``data`` by MediaFile objects
            for key in obj['data'].keys():
                obj['data'][key] = self.create_media_file(obj['data'][key])
        else:
            self.__load_unmanaged_files(obj['unmanaged'])

        super().__init__(obj, parent)

        # create data and temp folders next to the model folder
        for sub_folder in ('data', 'temp'):
            folder_path = path.join('projects', self['id'], sub_folder)
            if not path.exists(folder_path):
                mkdir(folder_path)

        # subscribe to changes to write to disk afterwards
        self.subscribe(lambda d, k: self.parent.write_project(self['id']))

    def __load_unmanaged_files(self, directory):
        """Create linked unmanaged media files for the files in ``directory``.

        Entries are visited in sorted name order so the prev/next chain is
        deterministic (``listdir`` order is arbitrary). Sub-directories and
        files with unsupported extensions are skipped instead of aborting
        the whole project load.
        """
        supported = self.IMAGE_EXTENSIONS + self.VIDEO_EXTENSIONS
        previous = None

        for name in sorted(listdir(directory)):
            identifier, extension = splitext(name)
            if extension.lower() not in supported:
                continue
            if not path.isfile(path.join(directory, name)):
                continue

            current = self.create_media_file({
                'id': identifier,
                'extension': extension
            }, unmanaged=True)

            if previous is not None:
                current.prev(previous)
                previous.next(current)
            previous = current

            # NOTE(review): two files sharing a stem (a.jpg / a.png) get the
            # same id; the dict keeps only the last one — confirm intended.
            self.unmanaged_files_keys.append(identifier)
            self.unmanaged_files[identifier] = current

        length = len(self.unmanaged_files_keys)
        for media_file in self.unmanaged_files.values():
            media_file.length(length)

    def update_properties(self, update):
        """Apply ``update`` to this project via ``set_recursive``."""
        set_recursive(update, self)

    def get_media_file(self, identifier):
        """Return the media file with ``identifier`` or ``None`` if unknown.

        Looks in the unmanaged file map for unmanaged projects and in
        ``data`` otherwise.
        """
        if self['unmanaged'] is not None:
            return self.unmanaged_files.get(identifier)

        if identifier not in self['data'].keys():
            return None
        return self['data'][identifier]

    def new_media_file_path(self):
        """Return ``(data folder, fresh uuid1 string)`` for a new upload."""
        return path.join('projects', self['id'], 'data'), str(uuid1())

    def create_media_file(self, file, unmanaged=False):
        """Wrap the raw ``file`` dict in the matching MediaFile class.

        The extension is matched case-insensitively; the original
        ``file['extension']`` value is left untouched.

        :raises NotImplementedError: if the extension is not supported
        """
        # TODO determine type by content instead of extension
        extension = file['extension'].lower()

        if extension in self.IMAGE_EXTENSIONS:
            return UnmanagedImageFile(file, self) if unmanaged else ImageFile(file, self)
        if extension in self.VIDEO_EXTENSIONS:
            return UnmanagedVideoFile(file, self) if unmanaged else VideoFile(file, self)

        raise NotImplementedError(
            'unsupported media file extension: {!r}'.format(file['extension'])
        )

    def add_media_file(self, uuid, name, extension, size, created):
        """Register a new managed media file in ``data`` under ``uuid``."""
        file = {
            'id': uuid,
            'name': name,
            'extension': extension,
            'size': size,
            'created': created
        }
        self['data'][file['id']] = self.create_media_file(file)

    def remove_media_file(self, file_id):
        """Delete the media file entry ``file_id`` from ``data``."""
        del self['data'][file_id]

    def add_label(self, name):
        """Create a new label with a fresh uuid1 id."""
        label_uuid = str(uuid1())
        self['labels'][label_uuid] = {
            'id': label_uuid,
            'name': name
        }

    def update_label(self, identifier, name):
        """Rename label ``identifier``; unknown ids are ignored."""
        if identifier in self['labels']:
            self['labels'][identifier]['name'] = name

    def remove_label(self, identifier):
        """Delete a label and every prediction result in ``data`` using it.

        Unknown ids are ignored.
        """
        if identifier not in self['labels']:
            return

        # collect first: deleting while iterating would mutate the dicts
        stale = []
        for file_id in self['data']:
            results = self['data'][file_id]['predictionResults']
            for result_id in results:
                result = results[result_id]
                if 'label' in result and result['label'] == identifier:
                    stale.append((file_id, result_id))

        for file_id, result_id in stale:
            del self['data'][file_id]['predictionResults'][result_id]

        del self['labels'][identifier]

    def predict(self, identifiers):
        """Run the pipeline on every known media file in ``identifiers``.

        Unknown identifiers are skipped. The idle-shutdown timer is
        (re)scheduled even if the pipeline raises, so the pipeline is
        never left open indefinitely.
        """
        pipeline = self.__create_pipeline()

        try:
            for file_id in identifiers:
                media_file = self.get_media_file(file_id)
                if media_file is not None:
                    pipeline.run(media_file)
        finally:
            self.__create_quit_pipeline_thread()

    def fit(self):
        """Run ``fit`` on the pipeline, then schedule the idle shutdown."""
        pipeline = self.__create_pipeline()

        try:
            pipeline.fit()
        finally:
            self.__create_quit_pipeline_thread()

    def __create_pipeline(self):
        """Cancel any pending shutdown and return the (possibly new) pipeline."""
        self.__quit_pipeline_thread()

        if self.pipeline_manager is None:
            self.pipeline_manager = PipelineManager(self)

        return self.pipeline_manager

    def __quit_pipeline(self):
        """Timer callback: close and drop the pipeline manager if present."""
        # the timer invoking this callback has fired; its handle is stale
        self.quit_pipeline_thread = None

        if self.pipeline_manager is not None:
            self.pipeline_manager.close()
            self.pipeline_manager = None

    def __create_quit_pipeline_thread(self):
        """(Re)schedule the pipeline shutdown, cancelling any pending one."""
        self.__quit_pipeline_thread()

        self.quit_pipeline_thread = spawn_after(self.DEFAULT_PIPELINE_TIMEOUT, self.__quit_pipeline)

    def __quit_pipeline_thread(self):
        """Cancel the pending pipeline shutdown timer, if any."""
        if self.quit_pipeline_thread is not None:
            self.quit_pipeline_thread.cancel()
            self.quit_pipeline_thread = None