import cv2
import io
import numpy as np
import os
import uuid

from PIL import Image
from flask import url_for
from pathlib import Path

from pycs.database.File import File
from pycs.util.FileOperations import BoundingBox

from tests.base import pаtch_tpool_execute
from tests.client.label_tests import _BaseLabelTests


class _BaseFileTests(_BaseLabelTests):
    """Shared helpers for the file endpoint tests.

    On top of the label test base this creates the folders the file tests
    write into and offers helpers that produce dummy JPEG images.
    """

    def setupModels(self):
        """Run the parent setup, then create the project's data and temp folders."""
        super().setupModels()
        root = Path(self.project.root_folder)
        data_root = Path(self.project.data_folder)

        for folder in [data_root, root / "temp"]:
            folder.mkdir(exist_ok=True, parents=True)

    def _get_dummy_image_bytes(self, size=(4000, 6000, 3)):
        """Return the bytes of an all-zero (black) image encoded as JPEG.

        :param size: shape of the zero array that is handed to ``Image.fromarray``
        :return: the encoded image as ``bytes``
        """
        # Allocate as uint8 right away. np.zeros defaults to float64, so the
        # former zeros(...).astype(uint8) built a ~576 MB temporary for the
        # default size before casting it down to 72 MB.
        pixels = np.zeros(size, dtype=np.uint8)

        buffer = io.BytesIO()
        Image.fromarray(pixels).save(buffer, "JPEG")

        # getvalue() returns the whole buffer regardless of the stream position.
        return buffer.getvalue()

    def _create_dummy_image(self, file_name, size=(4000, 6000, 3)):
        """Write a dummy JPEG into the project's data folder.

        :param file_name: name of the file, relative to ``self.project.data_folder``
        :param size: array shape forwarded to :meth:`_get_dummy_image_bytes`
        :return: tuple of the path that was written and the written bytes
        """
        absolute_path = os.path.join(self.project.data_folder, file_name)
        file_content = self._get_dummy_image_bytes(size=size)
        with open(absolute_path, "wb") as f:
            f.write(file_content)

        return absolute_path, file_content

class FileCreationTests(_BaseFileTests):
    """Tests for the ``upload_file`` endpoint."""

    @pаtch_tpool_execute
    def test_file_upload_project_with_external_data(self, mocked_execute=None):
        """Uploading into a project flagged as external data yields 400 and no File."""
        payload = b"some content+1"
        upload_url = url_for("upload_file", project_id=self.project.id)

        self.assertEqual(0, File.query.count())

        self.project.external_data = True
        self.project.commit()

        form = {"file": (io.BytesIO(payload), "image.jpg")}
        self.post(
            upload_url,
            data=form,
            content_type="multipart/form-data",
            status_code=400,
        )

        self.assertEqual(0, File.query.count())

    @pаtch_tpool_execute
    def test_file_upload(self, mocked_execute=None):
        """Check the 404/400 error paths and that a valid upload adds one File."""
        # A project id that does not exist.
        unknown_project_url = url_for("upload_file", project_id=4242)
        self.post(unknown_project_url, data={}, status_code=404)

        # The payload is a real (dummy) JPEG image.
        payload = self._get_dummy_image_bytes()
        upload_url = url_for("upload_file", project_id=self.project.id)

        self.assertEqual(0, File.query.count())

        # A request without a file part is rejected and stores nothing.
        self.post(upload_url, data={}, status_code=400)
        self.assertEqual(0, File.query.count())

        form = {"file": (io.BytesIO(payload), "image.jpg")}
        self.post(
            upload_url,
            data=form,
            content_type="multipart/form-data",
        )

        self.assertEqual(1, File.query.count())

        # The stored size cannot be verified here unless CONTENT_LENGTH is
        # set explicitly on the request:
        # file = File.query.first()
        # self.assertEqual(len(file_content), file.size)


class FileDeletionTests(_BaseFileTests):
    """Tests for the ``remove_file`` endpoint."""

    def test_file_removal(self):
        """A file is only removed when the request body carries ``remove=True``."""
        self._create_dummy_image("image.jpg")

        file, is_new = self.project.add_file(
            uuid=str(uuid.uuid1()),
            file_type="image",
            name="name",
            filename="image",
            extension=".jpg",
            size=32 * 1024,
        )

        self.assertTrue(is_new)
        self.assertEqual(1, self.project.files.count())
        self.assertTrue(os.path.exists(file.absolute_path))

        removal_url = url_for("remove_file", file_id=file.id)

        # An empty body and an explicit remove=False are both rejected.
        for rejected_body in ({}, {"remove": False}):
            self.post(removal_url, json=rejected_body, status_code=400)

        self.post(removal_url, json={"remove": True})
        self.assertEqual(0, self.project.files.count())
        self.assertFalse(os.path.exists(file.absolute_path))

        # A file id that does not exist.
        unknown_file_url = url_for("remove_file", file_id=4242)
        self.post(unknown_file_url, json={"remove": True}, status_code=404)

    def test_file_removal_from_project_with_external_data(self):
        """Removal is answered with 400 when the project uses external data."""
        self._create_dummy_image("image.jpg")

        file, is_new = self.project.add_file(
            uuid=str(uuid.uuid1()),
            file_type="image",
            name="name",
            filename="image",
            extension=".jpg",
            size=32 * 1024,
        )

        self.assertTrue(is_new)

        self.project.external_data = True
        self.assertTrue(os.path.exists(file.absolute_path))
        removal_url = url_for("remove_file", file_id=file.id)

        # The file entry has to survive the rejected request.
        self.assertEqual(1, self.project.files.count())
        self.post(removal_url, json={"remove": True}, status_code=400)
        self.assertEqual(1, self.project.files.count())


class FileGettingTests(_BaseFileTests):
    """Tests for the ``get_file`` and ``get_previous_and_next_file`` endpoints."""

    def test_get_file_getting(self):
        """The endpoint returns 404 until the file exists, then its raw content."""
        file, is_new = self.project.add_file(
            uuid=str(uuid.uuid1()),
            file_type="image",
            name="name",
            filename="image",
            extension=".jpg",
            size=32 * 1024,
        )

        self.assertTrue(is_new)
        self.assertEqual(1, self.project.files.count())

        file_url = url_for("get_file", file_id=file.id)

        # Only the database entry exists so far, nothing is stored on disk.
        self.get(file_url, status_code=404)

        _, written_bytes = self._create_dummy_image("image.jpg")

        response = self.get(file_url)

        self.assertFalse(response.is_json)
        self.assertEqual(written_bytes, response.data)

        response.close()

    def test_get_prev_next_file(self):
        """Neighbours are reported per file and skip files that were deleted."""
        for index in range(1, 6):
            file, is_new = self.project.add_file(
                uuid=str(uuid.uuid1()),
                file_type="image",
                name=f"name_{index}",
                filename=f"image_{index}",
                extension=".jpg",
                size=32 * 1024,
            )

            self.assertTrue(is_new)
            with open(file.absolute_path, "wb") as handle:
                handle.write(b"some content")

        self.assertEqual(5, self.project.files.count())
        files = self.project.files.all()

        # A file id that does not exist.
        unknown_file_url = url_for("get_previous_and_next_file", file_id=4542)
        self.get(unknown_file_url, status_code=404)

        last_index = len(files) - 1
        for index, file in enumerate(files):
            # The first file has no predecessor, the last one no successor.
            previous = files[index - 1].serialize() if index != 0 else None
            following = files[index + 1].serialize() if index < last_index else None

            neighbours_url = url_for("get_previous_and_next_file", file_id=file.id)

            response = self.get(neighbours_url)
            self.assertTrue(response.is_json)

            expected = {
                "current": file.serialize(),
                "next": following,
                "nextInCollection": following,
                "previous": previous,
                "previousInCollection": previous,
            }

            self.assertDictEqual(expected, response.json)

        # Delete the direct neighbours of the middle file one after the other;
        # each time the next remaining file has to take the deleted one's place.
        current = files[2]
        for removed, before, after in [(1, 0, 3), (3, 0, 4)]:
            files[removed].delete()
            previous_file, next_file = files[before], files[after]

            neighbours_url = url_for("get_previous_and_next_file", file_id=current.id)

            response = self.get(neighbours_url)
            self.assertTrue(response.is_json)

            expected = {
                "current": current.serialize(),
                "next": next_file.serialize(),
                "nextInCollection": next_file.serialize(),
                "previous": previous_file.serialize(),
                "previousInCollection": previous_file.serialize(),
            }

            self.assertDictEqual(expected, response.json)



class FileResizingTests(_BaseFileTests):
    """Tests for the resized/cropped image endpoints and the thumbnail files."""

    def _register_image_file(self, extension=".png", file_uuid=None):
        """Add an image file entry to the project and return it.

        :param extension: file extension stored for the entry
        :param file_uuid: UUID to use; a fresh ``uuid1`` is generated when omitted
        :return: the file object returned by ``self.project.add_file``
        """
        if file_uuid is None:
            file_uuid = str(uuid.uuid1())

        file, is_new = self.project.add_file(
            uuid=file_uuid,
            file_type="image",
            name="name",
            filename="image",
            extension=extension,
            size=32 * 1024,
        )

        self.assertTrue(is_new)
        return file

    def _add_image(self, shape, file: File):
        """Store a random uint8 image of the given shape at the file's path.

        :param shape: numpy shape of the image to generate
        :param file: file entry whose ``absolute_path`` is written
        :return: the generated pixel array
        """
        image = np.random.randint(0, 256, shape).astype(np.uint8)

        Image.fromarray(image).save(file.absolute_path)
        self.assertTrue(os.path.exists(file.absolute_path))
        return image

    def _compare_images(self, im0, im1, threshold=1e-3):
        """Assert that two images are close to each other.

        Both images are scaled by 1/255 and their mean squared error has to
        stay below ``threshold``.
        """
        im0, im1 = im0 / 255, im1 / 255
        mse = np.mean((im0 - im1) ** 2)
        self.assertLess(mse, threshold)

    @pаtch_tpool_execute
    def test_resize_image(self, mocked_execute):
        """Larger resolutions keep the original image, smaller ones shrink it."""
        # A file id that does not exist.
        self.get(url_for("get_resized_file", file_id=4242, resolution=300), status_code=404).close()

        file = self._register_image_file()
        image = self._add_image((300, 300), file)

        # Resolutions at or above the image size: the original is returned.
        for upscale in [300, 1200, 500, 320]:
            url = url_for("get_resized_file", file_id=file.id, resolution=upscale)
            response = self.get(url)

            self.assertFalse(response.is_json)

            returned_im = _im_from_bytes(response.data)
            response.close()

            self.assertEqual(image.shape, returned_im.shape)
            self._compare_images(image, returned_im)

        # The last resolution is requested twice so that the second request
        # is expected to be answered with the cached resized image.
        for downscale in [299, 200, 150, 32, 32]:
            sm_image = _resize(image, downscale)

            url = url_for("get_resized_file", file_id=file.id, resolution=downscale)
            response = self.get(url)

            self.assertFalse(response.is_json)

            returned_im = _im_from_bytes(response.data)
            response.close()

            self.assertEqual(sm_image.shape, returned_im.shape)
            self._compare_images(sm_image, returned_im)

    @pаtch_tpool_execute
    def test_resize_image_not_found(self, mocked_execute):
        """Resizing a file whose stored path does not exist is answered with 404."""
        file = self._register_image_file()
        self._add_image((300, 300), file)

        original_path = file.path
        file.path = "/some/nonexisting/path"
        file.commit()
        try:
            url = url_for("get_resized_file", file_id=file.id, resolution=300)
            # Close the response like every other request in this class does.
            self.get(url, status_code=404).close()
        finally:
            # Restore the real path even if the request assertion fails.
            file.path = original_path
            file.commit()

    @pаtch_tpool_execute
    def test_crop_image_not_found(self, mocked_execute):
        """Cropping a file whose stored path does not exist is answered with 404."""
        file = self._register_image_file()
        self._add_image((300, 300), file)

        original_path = file.path
        file.path = "/some/nonexisting/path"
        file.commit()
        try:
            url = url_for("get_cropped_file", file_id=file.id,
                resolution=300, crop_box="0x0x1x1")
            self.get(url, status_code=404).close()
        finally:
            # Restore the real path even if the request assertion fails.
            file.path = original_path
            file.commit()

    @pаtch_tpool_execute
    def test_crop_image(self, mocked_execute):
        """The cropped endpoint returns the same region as the local ``_crop`` helper."""
        file = self._register_image_file()
        image = self._add_image((300, 300), file)

        # Boxes are relative values, sent to the endpoint joined with "x".
        boxes = [(0, 0, 1, 1), (0, 0, 1/2, 1/2), (1/2, 1/2, 1, 1), (1/3, 1/2, 3/4, 1)]
        for box in boxes:
            url = url_for("get_cropped_file", file_id=file.id,
                resolution=300, crop_box="x".join(map(str, box)))
            response = self.get(url)

            self.assertFalse(response.is_json)

            returned_im = _im_from_bytes(response.data)
            response.close()

            crop = _crop(image, BoundingBox(*box))
            self.assertEqual(crop.shape, returned_im.shape)
            self._compare_images(crop, returned_im)

    def test_automatic_thumbnail_generation(self):
        """After a file is added, size-limited thumbnails exist in the temp folder."""
        img_size = (4000, 6000, 3)
        self._create_dummy_image("image.jpg", size=img_size)

        file_uuid = str(uuid.uuid1())
        file = self._register_image_file(extension=".jpg", file_uuid=file_uuid)

        self.assertEqual(1, self.project.files.count())

        self.assertTrue(os.path.exists(file.absolute_path))

        temp_folder = os.path.join(self.project.root_folder, "temp")
        for max_width, max_height in [(200, 200), (2000, 800)]:
            img_path = os.path.join(temp_folder, f"{file_uuid}_{max_width}_{max_height}.jpg")

            self.assertTrue(os.path.exists(img_path))

            with Image.open(img_path) as img:
                width, height = img.size

            # The thumbnail touches one of the limits and exceeds neither.
            self.assertTrue(width == max_width or height == max_height)
            self.assertLessEqual(width, max_width)
            self.assertLessEqual(height, max_height)
            # img_size is (rows, columns, channels), so the index 1 / index 0
            # ratio is width / height of the source image.
            self.assertLessEqual(abs(img_size[1] / img_size[0] - width / height), 0.1)

def _im_from_bytes(data: bytes) -> np.ndarray:
    """Open encoded image bytes with PIL and return the image as a numpy array."""
    stream = io.BytesIO(data)
    decoded = Image.open(stream)
    return np.asarray(decoded)


def _resize(image: np.ndarray, size: int) -> np.ndarray:
    """Resize *image* to ``size`` x ``size`` pixels via PIL and return it as an array."""
    target = (size, size)
    resized = Image.fromarray(image).resize(target)
    return np.asarray(resized)


def _crop(image: np.ndarray, box: BoundingBox) -> np.ndarray:
    """Cut the region described by a relative bounding box out of *image*.

    ``box.x`` / ``box.w`` are multiplied by the image width and ``box.y`` /
    ``box.h`` by the image height; the products are truncated to integers to
    get the pixel offset and extent of the crop.
    """
    # Only the two leading axes matter; trailing channel axes are kept as-is.
    height, width = image.shape[:2]

    left = int(width * box.x)
    top = int(height * box.y)
    right = left + int(width * box.w)
    bottom = top + int(height * box.h)

    return image[top:bottom, left:right]