from flask import url_for


from pycs import db
from pycs.database.Label import Label
from pycs.database.Model import Model
from pycs.database.Project import Project

from tests.client.project_tests import _BaseProjectTests


class _BaseLabelTests(_BaseProjectTests):
    """Shared fixture for the label tests: provides ``self.project``."""

    def setupModels(self):
        """Run the parent model setup and then create the test project."""
        super().setupModels()

        project_settings = {
            "name": "test_project",
            "description": "Project for a test case",
            "model": self.model,
            "root_folder": "project_folder",
            "external_data": False,
            "data_folder": "project_folder/data",
        }
        self.project = Project.new(**project_settings)

    def tearDown(self):
        """Delete the test project before delegating to the parent tear-down."""
        self.project.delete()
        super().tearDown()

class LabelCreationTests(_BaseLabelTests):
    """Tests that post to the ``create_label`` endpoint of the test project."""

    def setUp(self):
        super().setUp()
        self.url = url_for("create_label", project_id=self.project.id)

    @staticmethod
    def _newest_label():
        """Return the label with the highest id (across all projects)."""
        return Label.query.order_by(Label.id.desc()).first()

    def test_create_labels(self):
        """Labels with distinct names are created one per request."""
        self.assertEqual(0, self.project.labels.count())

        # A request without any payload is expected to be rejected
        # and must not create a label.
        self.post(self.url, json={}, status_code=400)
        self.assertEqual(0, self.project.labels.count())

        for n_created in range(1, 11):
            expected_name = f"Label{n_created}"
            self.post(self.url, json={"name": expected_name})
            self.assertEqual(self.project.labels.count(), n_created)
            newest = self._newest_label()
            self.assertEqual(newest.name, expected_name)

    def test_create_label_with_same_name(self):
        """Several labels may share one name; each request adds a label."""
        self.assertEqual(0, self.project.labels.count())

        shared_name = "SameName"
        for n_created in range(1, 11):
            self.post(self.url, json={"name": shared_name})
            self.assertEqual(self.project.labels.count(), n_created)
            newest = self._newest_label()
            self.assertEqual(newest.name, shared_name)

    def test_create_labels_with_reference(self):
        """Name and reference given in the request are stored on the label."""
        self.assertEqual(0, self.project.labels.count())

        for n_created in range(1, 11):
            expected_name = f"Label{n_created}"
            expected_ref = f"label_ref{n_created}"
            payload = {"name": expected_name, "reference": expected_ref}
            self.post(self.url, json=payload)
            self.assertEqual(self.project.labels.count(), n_created)
            newest = self._newest_label()
            self.assertEqual(newest.name, expected_name)
            self.assertEqual(newest.reference, expected_ref)

    def test_create_labels_with_same_reference(self):
        """A reference that is already in use is rejected with status 400."""
        self.assertEqual(0, self.project.labels.count())

        shared_ref = "same_ref"
        first_name = "Label"
        self.post(self.url, json={"name": first_name, "reference": shared_ref})
        newest = self._newest_label()
        self.assertEqual(newest.name, first_name)
        self.assertEqual(newest.reference, shared_ref)

        # Every further request re-using the reference must leave the
        # first label as the only (and newest) one.
        for suffix in range(2, 11):
            payload = {"name": f"Label{suffix}", "reference": shared_ref}
            self.post(self.url, status_code=400, json=payload)
            self.assertEqual(self.project.labels.count(), 1)
            newest = self._newest_label()
            self.assertEqual(newest.name, first_name)
            self.assertEqual(newest.reference, shared_ref)


class LabelListTests(_BaseLabelTests):
    """Tests for the ``list_labels`` and ``list_label_tree`` endpoints."""

    def setupModels(self):
        """Create a three-level label hierarchy in the test project.

        Ten "Level1" labels are created; each gets three "Level2" children,
        and each of those gets three children without an explicit
        hierarchy level.
        """
        super().setupModels()

        for i in range(1, 11):
            parent, is_new = self.project.create_label(
                name=f"Label{i}",
                reference=f"label{i}",
                hierarchy_level="Level1",
            )
            self.assertTrue(is_new)

            for j in range(1, 4):
                label, is_new = self.project.create_label(
                    name=f"Label{i}_{j}",
                    reference=f"label{i}_{j}",
                    hierarchy_level="Level2",
                    parent=parent)
                self.assertTrue(is_new)

                for k in range(1, 4):
                    # only the "is new" flag of the innermost labels is needed
                    _, is_new = self.project.create_label(
                        name=f"Label{i}_{j}_{k}",
                        reference=f"label{i}_{j}_{k}",
                        parent=label)
                    self.assertTrue(is_new)

    def test_list_labels(self):
        """The flat list contains the serialization of every project label."""
        self.get(url_for("list_labels", project_id=4242), status_code=404)

        response = self.get(url_for("list_labels", project_id=self.project.id))
        self.assertTrue(response.is_json)
        content = response.json

        labels = {label.id: label for label in self.project.labels.all()}

        self.assertEqual(len(labels), len(content))

        for entry in content:
            label = labels[entry["id"]]
            self.assertDictEqual(label.serialize(), entry)

    def test_list_label_tree(self):
        """The tree lists the root labels, serialized with their children."""
        self.get(url_for("list_label_tree", project_id=4242), status_code=404)

        response = self.get(url_for("list_label_tree", project_id=self.project.id))
        self.assertTrue(response.is_json)
        content = response.json

        root_query = Label.query.filter(
            Label.project_id == self.project.id,
            Label.parent_id.is_(None),
        )
        root_labels = {label.id: label for label in root_query.all()}

        self.assertEqual(len(root_labels), len(content))

        # The expected serialization has to include "children". The class
        # attribute is shared by every test, so it is put back afterwards;
        # otherwise the change would leak into tests that run later.
        original_serialize_only = Label.serialize_only
        Label.serialize_only += ("children",)
        try:
            for entry in content:
                label = root_labels[entry["id"]]
                self.assertDictEqual(label.serialize(), entry)
        finally:
            Label.serialize_only = original_serialize_only



class LabelRemovalTests(_BaseLabelTests):
    """Tests for the ``remove_label`` endpoint on a flat set of labels."""

    def url(self, label_id):
        """Build the ``remove_label`` URL for *label_id* in the test project."""
        return url_for(
            "remove_label",
            project_id=self.project.id,
            label_id=label_id,
        )

    def setupModels(self):
        """Create ``self.N`` labels and keep them in ``self.labels``."""
        super().setupModels()
        self.N = 10
        self.labels = []
        with db.session.begin_nested():
            for number in range(1, self.N + 1):
                # name and reference of a label are the same string
                identifier = f"Label{number}"
                created, is_new = self.project.create_label(
                    commit=False,
                    name=identifier,
                    reference=identifier,
                )
                assert is_new, "labels should be new"
                created.flush()
                self.labels.append(created)

    def test_remove_label_without_flag(self):
        """Requests lacking ``remove=True`` get status 400 and remove nothing."""
        self.assertEqual(self.N, self.project.labels.count())

        # no payload at all
        self.post(self.url(self.labels[0].id), status_code=400)
        self.assertEqual(self.N, self.project.labels.count())

        # flag present, but explicitly disabled
        payload = {"remove": False}
        self.post(self.url(self.labels[0].id), status_code=400, json=payload)
        self.assertEqual(self.N, self.project.labels.count())

    def test_remove_label(self):
        """Removing an existing label reduces the label count by one."""
        self.assertEqual(self.N, self.project.labels.count())

        payload = {"remove": True}
        self.post(self.url(self.labels[0].id), json=payload)
        self.assertEqual(self.N - 1, self.project.labels.count())

    def test_remove_non_existing_label(self):
        """An unknown label id gets status 404 and leaves all labels in place."""
        self.assertEqual(self.N, self.project.labels.count())

        payload = {"remove": True}
        self.post(self.url(424242), json=payload, status_code=404)
        self.assertEqual(self.N, self.project.labels.count())


class LabelTreeRemovalTests(_BaseLabelTests):
    """Tests for the ``remove_label`` endpoint on a two-level label tree."""

    def url(self, label_id):
        """Build the ``remove_label`` URL for *label_id* in the test project."""
        return url_for("remove_label",
            project_id=self.project.id, label_id=label_id)

    def setupModels(self):
        """Create ``n_roots`` root labels with ``leafs_per_root`` children each."""
        super().setupModels()
        self.n_roots = 5
        self.leafs_per_root = 3
        self.n_leafs = self.n_roots * self.leafs_per_root
        self.N = self.n_roots + self.n_leafs

        with db.session.begin_nested():
            for i in range(1, self.n_roots+1):
                root_label, is_new = self.project.create_label(
                    commit=False,
                    name=f"Label{i}",
                    reference=f"Label{i}",
                )
                # the counts asserted in the tests rely on every label
                # actually having been created
                assert is_new, "labels should be new"
                root_label.flush()

                for j in range(1, self.leafs_per_root+1):
                    label, is_new = self.project.create_label(
                        commit=False,
                        name=f"Label{i}_{j}",
                        reference=f"Label{i}_{j}",
                        parent=root_label
                    )
                    assert is_new, "labels should be new"
                    label.flush()

    @property
    def root_labels(self):
        """Query for the project's labels that have no parent."""
        return self.project.labels.filter(Label.parent_id.is_(None))

    @property
    def leaf_labels(self):
        """Query for the project's labels that have a parent."""
        return self.project.labels.filter(Label.parent_id.isnot(None))

    def _assert_initial_counts(self):
        """Check that the label tree is in the state built by setupModels."""
        self.assertEqual(self.N, self.project.labels.count())
        self.assertEqual(self.n_roots, self.root_labels.count())
        self.assertEqual(self.n_leafs, self.leaf_labels.count())

    def test_remove_root_label(self):
        """Removing a root label turns its children into root labels."""
        self._assert_initial_counts()

        label = self.root_labels.first()

        self.post(self.url(label.id), json=dict(remove=True))

        self.assertEqual(self.N-1, self.project.labels.count())
        # one root is gone, its former children are parentless now
        self.assertEqual(self.n_roots-1+self.leafs_per_root, self.root_labels.count())
        self.assertEqual(self.n_leafs-self.leafs_per_root, self.leaf_labels.count())

    def test_remove_leaf_label(self):
        """Removing a child label leaves the root labels untouched."""
        self._assert_initial_counts()

        label = self.leaf_labels.first()
        self.post(self.url(label.id), json=dict(remove=True))

        self.assertEqual(self.N-1, self.project.labels.count())
        self.assertEqual(self.n_roots, self.root_labels.count())
        self.assertEqual(self.n_leafs-1, self.leaf_labels.count())

    def test_remove_intermediate_label(self):
        """Removing a label with parent and child re-parents the child."""
        self._assert_initial_counts()

        label = self.leaf_labels.first()
        label_parent = label.parent

        # hang a new label below an existing child, making that child
        # an intermediate node of the tree
        new_label, is_new = self.project.create_label(
            commit=False,
            name="New Label",
            reference="new_ref",
            parent=label,
        )

        self.assertTrue(is_new)
        self.assertEqual(label.id, new_label.parent_id)
        self.assertEqual(self.n_roots, self.root_labels.count())
        self.assertEqual(self.n_leafs + 1, self.leaf_labels.count())

        self.post(self.url(label.id), json=dict(remove=True))

        # the new label is expected to be attached to the removed
        # label's former parent
        self.assertEqual(label_parent.id, new_label.parent_id)
        self.assertEqual(self.n_roots, self.root_labels.count())
        self.assertEqual(self.n_leafs, self.leaf_labels.count())

class LabelEditTests(_BaseLabelTests):
    """Tests for the ``edit_label_name`` and ``edit_label_parent`` endpoints."""

    def _url(self, endpoint, label_id):
        """Build the URL of *endpoint* for *label_id* in the test project."""
        return url_for(endpoint,
            project_id=self.project.id, label_id=label_id)

    def test_edit_name(self):
        """A label's name can be changed through ``edit_label_name``."""
        self.assertEqual(0, self.project.labels.count())
        # create_label returns a (label, is_new) tuple; the label is
        # fetched through the project relationship below instead
        self.project.create_label(name="Label", reference="label")
        self.assertEqual(1, self.project.labels.count())
        label = self.project.labels.first()
        self.assertEqual("Label", label.name)

        # a request without a name is expected to be rejected
        self.post(self._url("edit_label_name", label.id),
            json=dict(), status_code=400)

        name = "Another name"
        # unknown label id
        self.post(self._url("edit_label_name", 4242),
            json=dict(name=name), status_code=404)
        self.post(self._url("edit_label_name", label.id),
            json=dict(name=name))

        label = self.project.labels.first()
        self.assertEqual(name, label.name)

    def test_edit_parent(self):
        """A label's parent can be set through ``edit_label_parent``."""
        self.assertEqual(0, self.project.labels.count())
        label, _ = self.project.create_label(name="Label", reference="label")
        parent, _ = self.project.create_label(name="Label2", reference="label2")
        self.assertEqual(2, self.project.labels.count())

        label = Label.query.get(label.id)
        self.assertIsNone(label.parent_id)

        # a request without a parent is expected to be rejected
        self.post(self._url("edit_label_parent", label.id),
            json=dict(), status_code=400)

        # unknown label id
        self.post(self._url("edit_label_parent", 4242),
            json=dict(parent=parent.id), status_code=404)
        self.post(self._url("edit_label_parent", label.id),
            json=dict(parent=parent.id))

        label = Label.query.get(label.id)
        self.assertEqual(parent.id, label.parent_id)