import uuid

from flask import url_for

from pycs.database.Label import Label
from pycs.database.Result import Result

from tests.client.file_tests import _BaseFileTests


class _BaseResultTests(_BaseFileTests):
    """Shared fixture for the result endpoint tests.

    In addition to the models set up by ``_BaseFileTests.setupModels``
    (which provides ``self.project``), every test gets one freshly added
    image file, ``self.file``, that results can be attached to.
    """

    def setupModels(self):
        """Create the base models plus the image file used by the tests."""
        super().setupModels()

        file_uuid = str(uuid.uuid1())
        self.file, is_new = self.project.add_file(
            uuid=file_uuid,
            file_type="image",
            name="name",
            filename="image",
            extension=".jpg",
            size=32*1024,
        )
        # use the TestCase assertion instead of a bare ``assert``: it is not
        # stripped under ``python -O`` and matches the other tests' style
        self.assertTrue(is_new, "the created file should be new!")

class ResultCreationTests(_BaseResultTests):
    """Tests for the ``create_result`` endpoint."""

    def test_missing_file(self):
        """Posting a result for an unknown file id yields a 404."""
        self.post(url_for("create_result", file_id=4242), status_code=404)

    def test_missing_flags(self):
        """Malformed payloads are rejected with 400 and create no result."""
        create_url = url_for("create_result", file_id=self.file.id)

        invalid_payloads = (
            None,                       # no request body at all
            {},                         # "type" is missing
            {"type": "something"},      # must be "labeled-image" or "bounding-box"
            {"type": "labeled_image"},  # wrong separator, "-" expected
            {"type": "bounding_box"},   # wrong separator, "-" expected
            {"type": "labeled-image"},  # "label" is missing
            {"type": "bounding-box"},   # "data" is missing
        )

        for payload in invalid_payloads:
            self.assertEqual(0, Result.query.count())
            self.post(create_url, status_code=400, json=payload)
            self.assertEqual(0, Result.query.count())

    def test_file_label(self):
        """A labeled-image result references the label and carries no data."""
        create_url = url_for("create_result", file_id=self.file.id)

        label, created = self.project.create_label(name="label", reference="some_label")
        self.assertTrue(created)

        self.assertEqual(0, Result.query.count())
        payload = {"type": "labeled-image", "label": label.id}
        self.post(create_url, json=payload)
        self.assertEqual(1, Result.query.count())

        stored = Result.query.one_or_none()
        self.assertIsNotNone(stored)

        self.assertEqual("user", stored.origin)
        self.assertEqual(self.file.id, stored.file_id)
        self.assertEqual(label.id, stored.label_id)
        self.assertEqual(label.name, stored.label.name)
        self.assertIsNone(stored.data_encoded)
        self.assertIsNone(stored.data)

    def test_bounding_box(self):
        """A bounding-box result stores its data and has no label."""
        create_url = url_for("create_result", file_id=self.file.id)

        self.assertEqual(0, Result.query.count())
        expected_box = {"x": 0, "y": 0.5, "w": 1/3, "h": 1/4}
        payload = {"type": "bounding-box", "data": expected_box}
        self.post(create_url, json=payload)
        self.assertEqual(1, Result.query.count())

        stored = Result.query.one_or_none()
        self.assertIsNotNone(stored)

        self.assertEqual("user", stored.origin)
        self.assertEqual(self.file.id, stored.file_id)
        self.assertIsNotNone(stored.data_encoded)
        self.assertDictEqual(expected_box, stored.data)

        self.assertIsNone(stored.label_id)

class ResultGettingTests(_BaseResultTests):
    """Tests for the ``get_results`` endpoint."""

    def test_missing_file(self):
        """Requesting the results of an unknown file id yields a 404."""
        url = url_for("get_results", file_id=4242)
        self.get(url, status_code=404)

    def test_getting_of_results(self):
        """Only the requested file's results are returned, each serialized."""
        n = 5

        self.assertEqual(0, Result.query.count())

        # results of the file under test, keyed by id for the checks below
        results = {}
        for _ in range(n):
            box = dict(x=0, y=0, w=0.9, h=1.0)
            res = self.file.create_result("user", "bounding-box", data=box)
            results[res.id] = res

        self.assertEqual(n, Result.query.count())

        # a second file whose results must NOT appear in the response
        another_file, is_new = self.project.add_file(
            uuid=str(uuid.uuid1()),
            file_type="image",
            name="name2",
            filename="image2",
            extension=".jpg",
            size=32*1024,
        )
        self.assertTrue(is_new)

        for _ in range(n):
            box = dict(x=0, y=0, w=0.9, h=1.0)
            another_file.create_result(
                "user", "bounding-box", origin_user="dummy_username", data=box)

        self.assertEqual(2 * n, Result.query.count())

        url = url_for("get_results", file_id=self.file.id)

        response = self.get(url)
        self.assertTrue(response.is_json)

        content = response.json

        self.assertEqual(n, len(content))
        # exactly the results of ``self.file`` are returned (no foreign ids,
        # none missing); gives a readable diff instead of a KeyError
        self.assertSetEqual(set(results), {entry["id"] for entry in content})

        for entry in content:
            self.assertDictEqual(results[entry["id"]].serialize(), entry)


class ResultEditingTests(_BaseResultTests):
    """Tests for editing result data/labels and for confirming results."""

    def test_edit_data(self):
        """Editing a result's data replaces it and sets the origin to user."""

        self.post(url_for("edit_result_data", result_id=4242), status_code=404)

        box0 = dict(x=0, y=0, w=0.9, h=1.0)
        result = self.file.create_result("pipeline", "bounding-box", data=box0)

        self.assertDictEqual(box0, result.data)
        self.assertEqual("pipeline", result.origin)

        url = url_for("edit_result_data", result_id=result.id)

        for data in [None, dict()]:
            self.post(url, status_code=400, json=data)

        # rejected requests must leave the result untouched
        self.assertDictEqual(box0, result.data)
        self.assertEqual("pipeline", result.origin)

        # box1 must differ from box0; otherwise the check below could not
        # detect whether the data was actually updated
        box1 = dict(x=0.25, y=0.5, w=0.5, h=0.25)
        self.assertNotEqual(box0, box1)
        self.post(url, json=dict(data=box1))

        self.assertDictEqual(box1, result.data)
        self.assertEqual("user", result.origin)

    def test_edit_label(self):
        """Changing the label of a result sets the new label and user origin."""
        self.post(url_for("edit_result_label", result_id=4242), status_code=404)

        label1, is_new = self.project.create_label(name="label1", reference="some_label1")
        self.assertTrue(is_new)

        label2, is_new = self.project.create_label(name="label2", reference="some_label2")
        self.assertTrue(is_new)

        for result_type in ["labeled-image", "bounding-box"]:

            result = self.file.create_result("pipeline", result_type, label=label1)

            self.assertEqual(label1.id, result.label_id)
            self.assertEqual(label1.name, result.label.name)
            self.assertEqual("pipeline", result.origin)

            url = url_for("edit_result_label", result_id=result.id)

            for data in [None, dict()]:
                self.post(url, status_code=400, json=data)

            # a labeled-image result may not lose its label
            if result_type == "labeled-image":
                self.post(url, status_code=400, json=dict(label=None))

            self.post(url, json=dict(label=label2.id))

            self.assertEqual(label2.id, result.label_id)
            self.assertEqual(label2.name, result.label.name)
            self.assertEqual("user", result.origin)

    def test_unset_label_of_bounding_box(self):
        """A bounding-box result's label can be removed by posting None."""
        label, is_new = self.project.create_label(name="label", reference="some_label")
        self.assertTrue(is_new)

        result = self.file.create_result("pipeline", "bounding-box", label=label)

        self.assertEqual(label.id, result.label_id)
        self.assertEqual(label.name, result.label.name)
        self.assertEqual("pipeline", result.origin)

        url = url_for("edit_result_label", result_id=result.id)

        self.post(url, json=dict(label=None))

        self.assertIsNone(result.label_id)
        self.assertEqual("user", result.origin)

    def test_confirm_result(self):
        """Only an explicit ``confirm=True`` switches the origin to user."""
        self.post(url_for("confirm_result", result_id=4242), status_code=404)

        label, is_new = self.project.create_label(name="label", reference="some_label1")
        self.assertTrue(is_new)

        for result_type in ["labeled-image", "bounding-box"]:
            result = self.file.create_result("pipeline", result_type, label=label)
            url = url_for("confirm_result", result_id=result.id)

            for data in [None, dict(), dict(confirm=False)]:
                self.post(url, status_code=400, json=data)

            self.assertEqual("pipeline", result.origin)
            self.post(url, json=dict(confirm=True))
            self.assertEqual("user", result.origin)

class ResultRemovalTests(_BaseResultTests):
    """Tests for removing single results and resetting a file's results."""

    def test_remove_result(self):
        """Only an explicit ``remove=True`` deletes the result."""
        self.post(url_for("remove_result", result_id=4242), status_code=404)

        label, is_new = self.project.create_label(name="label", reference="some_label1")
        self.assertTrue(is_new)

        self.assertEqual(0, Result.query.count())
        for result_type in ["labeled-image", "bounding-box"]:
            result = self.file.create_result("pipeline", result_type, label=label)
            self.assertEqual(1, Result.query.count())
            self.assertEqual(1, self.file.results.count())

            url = url_for("remove_result", result_id=result.id)

            for data in [None, dict(), dict(remove=False)]:
                self.post(url, status_code=400, json=data)

            # rejected requests must not delete anything
            self.assertEqual(1, Result.query.count())
            self.assertEqual(1, self.file.results.count())

            self.post(url, json=dict(remove=True))

            self.assertEqual(0, Result.query.count())
            self.assertEqual(0, self.file.results.count())

        self.assertEqual(0, Result.query.count())

    def test_reset_file_results(self):
        """Only an explicit ``reset=True`` deletes all results of the file."""

        self.post(url_for("reset_results", file_id=4242), status_code=404)

        label, is_new = self.project.create_label(name="label", reference="some_label1")
        self.assertTrue(is_new)

        self.assertEqual(0, Result.query.count())

        n = 5
        for result_type in ["labeled-image", "bounding-box"]:
            for _ in range(n):
                self.file.create_result("pipeline", result_type, label=label)

            url = url_for("reset_results", file_id=self.file.id)

            for data in [None, dict(), dict(reset=False)]:
                self.post(url, status_code=400, json=data)

            # rejected requests must not delete anything
            self.assertEqual(n, Result.query.count())
            self.assertEqual(n, self.file.results.count())
            self.post(url, json=dict(reset=True))

            self.assertEqual(0, Result.query.count())
            self.assertEqual(0, self.file.results.count())