import unittest

from pycs import db
from pycs.database.File import File
from pycs.database.Label import Label
from pycs.database.LabelProvider import LabelProvider
from pycs.database.Model import Model
from pycs.database.Project import Project
from pycs.database.Result import Result

from tests.base import BaseTestCase


class TestDatabase(BaseTestCase):
    """Checks the Model, LabelProvider and Project database entities.

    Every test asserts against the fixtures created by ``setupModels``:
    three models, two label providers and three projects.
    """

    def setupModels(self):
        """Create the fixture rows the tests of this class assert on.

        Creates three models, two label providers and one project per model.

        NOTE(review): presumably invoked by ``BaseTestCase`` while setting up
        each test -- confirm in ``tests/base.py``.
        """
        # One entry per model; the position (1-based) is used in every
        # generated name and folder.
        supported_features = [
            ["labeled-image", "fit"],
            ["labeled-bounding-boxes"],
            ["labeled-bounding-boxes"],
        ]

        with db.session.begin_nested():
            for i, supports in enumerate(supported_features, 1):
                model = Model.new(
                    commit=False,
                    name=f"Model {i}",
                    description=f"Description for Model {i}",
                    root_folder=f"modeldir{i}",
                )
                model.supports = supports

                # Only two label providers are created, so that the third
                # project ends up without one.
                if i > 2:
                    continue

                LabelProvider.new(
                    commit=False,
                    name=f"Label Provider {i}",
                    description=f"Description for Label Provider {i}",
                    root_folder=f"labeldir{i}",
                    configuration_file=f"labeldir{i}/configuration.json",
                )

        # projects: one per model, reusing the rows created above
        models = Model.query.all()
        label_providers = LabelProvider.query.all()

        for i, model in enumerate(models, 1):
            Project.new(
                name=f'Project {i}',
                description=f'Project Description {i}',
                model=model,
                # the third project has no label provider
                label_provider=label_providers[i - 1] if i < 3 else None,
                root_folder=f'projectdir{i}',
                # only the second project uses external data
                external_data=i == 2,
                data_folder=f'datadir{i}',
            )

    def test_models(self):
        """Verify the stored fixture models and copying a model."""
        models = Model.query.all()

        # number of fixture models
        self.assertEqual(len(models), 3)

        # attributes of the fixture models
        for number, model in enumerate(models, 1):
            self.assertEqual(model.id, number)
            self.assertEqual(model.name, f'Model {number}')
            self.assertEqual(model.description, f'Description for Model {number}')
            self.assertEqual(model.root_folder, f'modeldir{number}')

        self.assertEqual(models[0].supports, ['labeled-image', 'fit'])
        self.assertEqual(models[1].supports, ['labeled-bounding-boxes'])
        self.assertEqual(models[2].supports, ['labeled-bounding-boxes'])

        # copying: the copy gets a new id, the given name and folder, and
        # keeps the description and supported features of its source
        copy, is_new = models[0].copy_to('Copied Model', 'some_other_dir')
        self.assertTrue(is_new)
        self.assertEqual(copy.id, 4)
        self.assertEqual(copy.name, 'Copied Model')
        self.assertEqual(copy.description, 'Description for Model 1')
        self.assertEqual(copy.root_folder, 'some_other_dir')
        self.assertEqual(copy.supports, ['labeled-image', 'fit'])

    def test_label_providers(self):
        """Verify the stored fixture label providers."""
        label_providers = LabelProvider.query.all()

        # number of fixture label providers
        self.assertEqual(len(label_providers), 2)

        for number, provider in enumerate(label_providers, 1):
            self.assertEqual(provider.id, number)
            self.assertEqual(provider.name, f'Label Provider {number}')
            self.assertEqual(provider.description, f'Description for Label Provider {number}')
            self.assertEqual(provider.root_folder, f'labeldir{number}')
            self.assertEqual(provider.configuration_file,
                f"labeldir{number}/configuration.json")

    def test_projects(self):
        """Verify the fixture projects, then deleting and updating a project."""
        models = Model.query.all()
        label_providers = LabelProvider.query.all()
        projects = Project.query.all()

        # number of fixture projects
        self.assertEqual(len(projects), 3)

        # attributes and relations of the fixture projects
        for i, project in enumerate(projects):
            number = i + 1
            # only the first two projects were given a label provider
            has_provider = i < 2

            self.assertEqual(project.id, number)
            self.assertEqual(project.name, f'Project {number}')
            self.assertEqual(project.description, f'Project Description {number}')
            self.assertEqual(project.model_id, number)
            self.assertEqual(project.model.__dict__, models[i].__dict__)
            self.assertEqual(
                project.label_provider_id,
                label_providers[i].id if has_provider else None
            )
            self.assertEqual(
                project.label_provider.__dict__ if project.label_provider is not None else None,
                label_providers[i].__dict__ if has_provider else None
            )
            self.assertEqual(project.root_folder, f'projectdir{number}')
            self.assertEqual(project.external_data, i == 1)
            self.assertEqual(project.data_folder, f'datadir{number}')

        # remove a project: the former second project becomes the first one
        Project.query.first().delete()

        self.assertEqual(Project.query.count(), 2)
        self.assertEqual(Project.query.first().name, 'Project 2')

        # set properties and check they are visible to a fresh query
        project = Project.query.first()

        project.name = 'Project 0'
        project.commit()
        self.assertEqual(Project.query.first().name, 'Project 0')

        project.description = 'Description 0'
        project.commit()
        self.assertEqual(Project.query.first().description, 'Description 0')



# Run the tests of this module through unittest when the file is executed
# directly instead of being imported.
if __name__ == '__main__':
    unittest.main()