import typing as T
import uuid

import cv2
import numpy as np

from flask import abort
from flask import make_response
from flask import request
from flask.views import View

from pycs.database.File import File
from pycs.database.Result import Result
from pycs.frontend.notifications.NotificationManager import NotificationManager
from pycs.jobs.JobGroupBusyException import JobGroupBusyException
from pycs.jobs.JobRunner import JobRunner

class EstimateBoundingBox(View):
    """
    create a result for a file by estimating a bounding box around a point

    Expects a JSON object body ``{"x": <number>, "y": <number>}`` whose
    coordinates are relative to the image size (both within ``[0, 1]``).
    The estimation itself runs in a background job started via the
    ``JobRunner``; ``self.nm.create_result`` is handed to that job as its
    ``result`` callback.
    """
    # pylint: disable=arguments-differ
    methods = ['POST']

    def __init__(self, nm: NotificationManager, jobs: JobRunner):
        # pylint: disable=invalid-name
        self.nm = nm
        self.jobs = jobs

    @staticmethod
    def _is_relative_coordinate(value) -> bool:
        """ check that value is a real number (not a bool) within [0, 1] """
        # bool is a subclass of int, but true/false are not coordinates
        if isinstance(value, bool) or not isinstance(value, (int, float)):
            return False
        # NaN fails this comparison and is therefore rejected as well
        return 0 <= value <= 1

    def dispatch_request(self, file_id: int):
        """
        Start a bounding box estimation job for a file.

        :param file_id: id of the file to run the estimation on
        :return: an empty response once the job has been handed to the runner
        :raises HTTPException: 404 if the file does not exist (via
            ``File.get_or_404``); 400 if the coordinates are missing or
            invalid, or if the job runner reports the job group as busy
        """
        file = File.get_or_404(file_id)
        request_data = request.get_json(force=True)

        # a JSON body that is not an object (e.g. a list or null) cannot
        # carry the coordinates; checking this avoids a TypeError (500)
        if (not isinstance(request_data, dict)
                or 'x' not in request_data
                or 'y' not in request_data):
            abort(400, "coordinates for the estimation are missing")

        x, y = request_data['x'], request_data['y']

        # reject bad input here instead of letting the background job crash
        if not (self._is_relative_coordinate(x)
                and self._is_relative_coordinate(y)):
            abort(400, "coordinates for the estimation must be numbers between 0 and 1")

        project = file.project
        try:
            # random suffix for the job identifier; presumably this lets several
            # estimations of one project coexist — confirm against JobRunner.run
            rnd = str(uuid.uuid4())[:10]
            self.jobs.run(project,
                          "Estimation",
                          f'{project.name} (create predictions)',
                          f"{project.id}/estimation/{rnd}",
                          estimate,
                          file.id, x, y,
                          result=self.nm.create_result
                          )

        except JobGroupBusyException:
            abort(400, "Job is already running!")

        return make_response()


def estimate(file_id: int, x: float, y: float) -> Result:
    """
    Estimate a bounding box around a point of a file's image and store it.

    This is the function that ``EstimateBoundingBox`` runs as a job.

    :param file_id: id of the file whose image is analyzed
    :param x: horizontal position relative to the image width, in ``[0, 1]``
    :param y: vertical position relative to the image height, in ``[0, 1]``
    :return: the result created by ``file.create_result`` for a
        ``bounding-box`` whose coordinates are relative to the image size
    :raises ValueError: if the coordinates are out of range, the file does
        not exist or its image cannot be read
    """
    # NaN fails these comparisons and is rejected as well
    if not (0 <= x <= 1 and 0 <= y <= 1):
        raise ValueError(f"relative coordinates out of range: x={x}, y={y}")

    file = File.query.get(file_id)
    if file is None:
        raise ValueError(f"file {file_id} does not exist")

    # cv2.imread does not raise for unreadable files, it returns None
    image = cv2.imread(file.absolute_path, cv2.IMREAD_GRAYSCALE)
    if image is None:
        raise ValueError(f"could not read the image of file {file_id}")

    h, w = image.shape
    # x == 1 (y == 1) would otherwise map one past the last column (row)
    pos = min(int(x * w), w - 1), min(int(y * h), h - 1)
    x0, y0, x1, y1 = detect(image, pos,
                            window_size=1000,
                            pixel_delta=50,
                            enlarge=1e-2,
                            )

    # store the box relative to the image size
    data = dict(
        x=x0 / w,
        y=y0 / h,
        w=(x1 - x0) / w,
        h=(y1 - y0) / h,
    )

    return file.create_result('pipeline', 'bounding-box', label=None, data=data)

def detect(image: np.ndarray,
           pos: T.Tuple[int, int],
           window_size: int = 1000,
           pixel_delta: int = 0,
           enlarge: float = -1) -> T.Tuple[int, int, int, int]:
    """
    Detect the extent of the region of similar intensity around a pixel.

    A pixel is part of the region mask if its intensity differs from the
    intensity at ``pos`` by strictly less than ``pixel_delta`` (so with the
    default ``pixel_delta=0`` the mask is empty). The region's borders are
    then searched in a square window centered on ``pos``, using the
    window's column and row sums.

    :param image: 2D image, indexed as ``image[row, column]``
    :param pos: ``(x, y)`` pixel position, i.e. ``(column, row)``
    :param window_size: side length of the search window around ``pos``
    :param pixel_delta: intensity tolerance (exclusive) of the region
    :param enlarge: fraction of the larger image side by which the box is
        grown on every side; no growth if that yields <= 0 pixels
    :return: ``(x0, y0, x1, y1)`` box corners in pixels, clamped to the image
    :raises ValueError: if ``pos`` lies outside the image or
        ``window_size`` is smaller than 1
    """
    x, y = pos
    height, width = image.shape[0], image.shape[1]
    if not (0 <= x < width and 0 <= y < height):
        raise ValueError(f"position {pos} lies outside the image ({width}x{height})")
    if window_size < 1:
        raise ValueError(f"window_size must be at least 1, got {window_size}")

    # compare in float64: in the image's own (typically uint8) dtype,
    # pixel +/- pixel_delta wraps around (e.g. 250 + 50 -> 44), which yields
    # an empty or wrong mask for bright or dark seed pixels
    values = image.astype(np.float64, copy=False)
    pixel = values[y, x]
    mask = np.logical_and(pixel - pixel_delta < values,
                          values < pixel + pixel_delta).astype(np.float32)

    # after padding, image pixel (x, y) sits at (x + pad, y + pad), so the
    # window starting at (x, y) is centered on it and never leaves the array
    pad = window_size // 2
    mask = np.pad(mask, pad, mode="constant")
    window = mask[y: y + window_size, x: x + window_size]

    # region profiles: column sums (horizontal) and row sums (vertical)
    sum_x, sum_y = window.sum(axis=0), window.sum(axis=1)

    enlarge_px = int(enlarge * max(image.shape))
    x0, x1 = get_borders(sum_x, enlarge_px)
    y0, y1 = get_borders(sum_y, enlarge_px)

    # window indices -> image coordinates, clamped to the image
    x0 = max(x + x0 - pad, 0)
    y0 = max(y + y0 - pad, 0)

    x1 = min(x + x1 - pad, width)
    y1 = min(y + y1 - pad, height)

    return x0, y0, x1, y1

def get_borders(arr, enlarge: int, eps=5e-1):
    """
    Find the borders of the region around the center of a 1D profile.

    The entry at the center index ``len(arr) // 2`` is the reference; an
    entry counts as background if it is below ``eps`` times that value.

    :param arr: 1D numpy array (e.g. row or column sums of a mask window)
    :param enlarge: number of entries to grow both borders by; ignored
        unless positive (the result stays within ``[0, len(arr) - 1]``)
    :param eps: fraction of the center value below which an entry is
        background
    :return: ``(lower, upper)`` as Python ints. ``lower`` is the last
        background index left of the center (0 if there is none). ``upper``
        is the first background index at or right of the center (or, if
        there is none, the position of the minimum of the right half).
    """
    center = len(arr) // 2
    left_half = arr[:center]
    right_half = arr[center:]
    cutoff = arr[center] * eps

    left_background = np.nonzero(left_half < cutoff)[0]
    right_background = np.nonzero(right_half < cutoff)[0]

    if left_background.size:
        lower = left_background[-1]
    else:
        lower = 0

    # the right half begins at the center, so its indices are offset by it
    if right_background.size:
        upper = center + right_background[0]
    else:
        # NOTE(review): unlike the left side this falls back to the minimum,
        # not to the window edge — confirm the asymmetry is intended
        upper = center + right_half.argmin()

    if enlarge > 0:
        lower = max(lower - enlarge, 0)
        upper = min(upper + enlarge, len(arr) - 1)

    return int(lower), int(upper)