from os import path
from typing import List
from urllib.request import urlretrieve

import cv2

from pycs.interfaces.MediaFile import MediaFile
from pycs.interfaces.MediaStorage import MediaStorage
from pycs.interfaces.Pipeline import Pipeline as Interface


class Pipeline(Interface):
    """Face detection pipeline based on OpenCV's frontal-face Haar cascade.

    Images and video frames are searched for frontal faces. Each detection is
    added to the media file as a bounding box whose coordinates are relative
    to the frame size, and the file is put into either the ``face`` or the
    ``none`` collection.
    """

    # source of the pretrained cascade; downloaded once into the root folder
    URL = 'https://raw.githubusercontent.com/opencv/opencv/master/data/haarcascades/haarcascade_frontalface_default.xml'

    # frames are downscaled so that neither side exceeds this many pixels
    MAX_SIZE = 2048
    # smallest face (width, height) reported, in pixels of the downscaled frame
    MIN_FACE_SIZE = (192, 192)

    def __init__(self, root_folder, distribution):
        """Load the Haar cascade, downloading it into ``root_folder`` if missing.

        :param root_folder: directory in which the cascade xml file is cached
        :param distribution: not used by this class
        :raises RuntimeError: if OpenCV cannot load the cascade file
        """
        import os

        print('hcffdv1 init')

        # get path to xml file
        xml_file = path.join(root_folder, 'haarcascade_frontalface_default.xml')

        # download to a temporary name and move it into place afterwards, so an
        # interrupted download cannot leave a truncated file behind that the
        # existence check would accept on every following start
        if not path.exists(xml_file):
            tmp_file = xml_file + '.part'
            urlretrieve(self.URL, tmp_file)
            os.replace(tmp_file, xml_file)

        # load; an unusable file may yield an empty classifier rather than an
        # exception, which would otherwise only surface in detectMultiScale
        self.face_cascade = cv2.CascadeClassifier(xml_file)
        if self.face_cascade.empty():
            raise RuntimeError(f'could not load cascade classifier from {xml_file}')

    def close(self):
        """Release pipeline resources (nothing to release besides logging)."""
        print('hcffdv1 close')

    def collections(self) -> List[dict]:
        """Return the two collections a file can be sorted into."""
        return [
            self.create_collection('face', 'face detected', autoselect=True),
            self.create_collection('none', 'no face detected')
        ]

    def execute(self, storage: MediaStorage, file: MediaFile):
        """Detect faces in ``file`` and assign it to a collection.

        Files of type ``'image'`` are analyzed once. Any other file is read
        frame by frame as a video, and each frame's detections are tagged
        with its zero-based frame index.

        :param storage: media storage (not used by this pipeline)
        :param file: media file to analyze
        :raises ValueError: if an image file cannot be decoded
        """
        print('hcffdv1 execute')

        if file.type == 'image':
            image = cv2.imread(file.path)
            # imread signals failure by returning None instead of raising
            if image is None:
                raise ValueError(f'could not read image file {file.path}')

            found = self.__find(file, image)
        else:
            found = False
            video = cv2.VideoCapture(file.path)

            # release the capture even if analyzing a frame raises
            try:
                index = 0
                ret, image = video.read()
                while ret:
                    # no early exit: every frame must get its boxes recorded
                    if self.__find(file, image, index):
                        found = True

                    ret, image = video.read()
                    index += 1
            finally:
                video.release()

        # set file collection
        file.set_collection('face' if found else 'none')

    def __find(self, file: MediaFile, image, frame=None):
        """Detect faces in a BGR image and add them to ``file`` as bounding boxes.

        :param file: media file that receives the bounding boxes
        :param image: BGR image as produced by OpenCV
        :param frame: frame index for videos, ``None`` for still images
        :return: True if at least one face was detected
        """
        # convert to grayscale, scale down
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        height, width = gray.shape

        # only ever downscale (factor capped at 1); keep at least one pixel per
        # side so extreme aspect ratios cannot request a zero-sized resize
        scale_factor = min(self.MAX_SIZE / width, self.MAX_SIZE / height, 1.0)
        scale_height = max(1, int(height * scale_factor))
        scale_width = max(1, int(width * scale_factor))
        scaled = cv2.resize(gray, (scale_width, scale_height))

        # detect faces
        faces = self.face_cascade.detectMultiScale(
            scaled,
            scaleFactor=1.1,
            minNeighbors=5,
            minSize=self.MIN_FACE_SIZE
        )

        # add faces to results; dividing by the scaled size yields coordinates
        # relative to the frame, so the boxes do not depend on the downscaling
        for x, y, w, h in faces:
            file.add_bounding_box(x / scale_width,
                                  y / scale_height,
                                  w / scale_width,
                                  h / scale_height,
                                  frame=frame)

        return len(faces) > 0