import eventlet
import logging
import os
import shutil
import typing as T
import unittest

import base64
from flask_htpasswd import HtPasswdAuth

from pathlib import Path
from unittest import mock

from pycs import app
from pycs import db
from pycs import settings
from pycs.database.LabelProvider import LabelProvider
from pycs.database.Model import Model
from pycs.frontend.WebServer import WebServer
from pycs.util.PipelineCache import PipelineCache

# WebServer instance shared by every BaseTestCase subclass in this process.
# It is created lazily by the first BaseTestCase.setUpClass call (while it is
# still None) and reused by all later test classes.
server: "T.Optional[WebServer]" = None

def patch_tpool_execute(test_func):
    """Decorate *test_func* so ``eventlet.tpool.execute`` runs synchronously.

    While the decorated test runs, ``eventlet.tpool.execute(func, *args,
    **kwargs)`` is replaced by a mock. The mock's side effect calls ``func``
    directly in the calling thread and returns its result, rather than handing
    it to eventlet's thread pool.

    :param test_func: the test function (or method) to wrap
    :return: the wrapped callable produced by ``mock.patch``
    """

    def call_func(func, *args, **kwargs):
        # Same call signature as tpool.execute, but executed inline.
        return func(*args, **kwargs)

    decorator = mock.patch("eventlet.tpool.execute", side_effect=call_func)
    return decorator(test_func)


# Backward-compatible alias: this helper was originally spelled with a
# Cyrillic "а" (U+0430) in place of the ASCII "a", which made it impossible
# to import under its intended name. Keep the old spelling importable.
pаtch_tpool_execute = patch_tpool_execute

class BaseTestCase(unittest.TestCase):
    """Common fixture for tests that exercise the Flask ``app``.

    Class level: configures the app for testing (SQLite test database, dummy
    htpasswd credentials ``user:password``), creates the module-wide
    :class:`WebServer` once, and starts ``server.pipelines``.

    Test level: creates the database tables, builds basic-auth headers, a
    test client and a pushed test request context. It also offers
    ``get``/``post`` helpers that send authenticated requests and assert the
    response status code.
    """

    # Seconds passed to eventlet.sleep while polling for background work.
    _sleep_time = 0.2
    # WebServer used by this class; taken from the module-level instance.
    server = None

    # SQLite database file used by the tests (in the current directory).
    DB_FILE = Path.cwd() / "test.sqlite"

    @classmethod
    def setUpClass(cls, discovery: bool = False):
        """Configure the app, create the shared server and start pipelines.

        :param discovery: forwarded to :class:`WebServer` when the shared
            server is created for the first time; ignored afterwards.
        """
        global server
        super().setUpClass()

        app.logger.setLevel(logging.CRITICAL)
        app.config["TESTING"] = True
        app.config["WTF_CSRF_ENABLED"] = False
        app.config["DEBUG"] = False
        app.config["SQLALCHEMY_DATABASE_URI"] = f"sqlite:///{cls.DB_FILE}"

        # Set dummy password protection.
        # This allows authentication with the credentials user:password.
        htpasswd_path = Path(".test-htpasswd")
        app.config["FLASK_HTPASSWD_PATH"] = str(htpasswd_path)
        if not htpasswd_path.is_file():
            htpasswd_path.write_text(
                "user:$apr1$fmi16nrq$3C4MfxW3ChrUNjSLLTB3x.",
                encoding="utf-8",
            )
        htpasswd = HtPasswdAuth(app)

        # Only the first test class in the process builds the server; later
        # classes reuse it.
        if server is None:
            settings["pipeline_cache_time"] = 2
            server = WebServer(app, htpasswd, settings, discovery)

        if cls.server is None:
            cls.server = server
        db.create_all()

        # # run discovery modules manually
        # Model.discover("models/")
        # LabelProvider.discover("labels/")
        cls.server.pipelines.start()

    def wait_for_bg_jobs(self, raise_errors=True):
        """Block until all server jobs have finished and pipelines are empty.

        Polling yields to other green threads via :meth:`wait_for_coroutines`.
        There is no timeout.

        :param raise_errors: if true, fail the test when a finished job
            recorded an exception.
        """
        # wait for JobRunner jobs to finish
        while True:
            ready = True
            for job in self.server.jobs.list():
                if job.finished is None:
                    app.logger.debug(f"{job} is not finished!")
                    ready = False
                    break

                if raise_errors:
                    self.assertIsNone(job.exception, job.exception)

            if ready:
                break

            self.wait_for_coroutines()

        # wait for PipelineCache to finish
        while not self.server.pipelines.is_empty:
            self.wait_for_coroutines()

    def wait_for_coroutines(self):
        """Sleep cooperatively so that other green threads can run."""
        eventlet.sleep(self._sleep_time)

    @classmethod
    def tearDownClass(cls):
        """Shut down the pipelines and wait until they report being empty."""
        super().tearDownClass()
        cls.server.pipelines.shutdown()

        while not cls.server.pipelines.is_empty:
            eventlet.sleep(cls._sleep_time)

    def setUp(self):
        """Create tables, auth headers, a test client and a request context."""
        super().setUp()
        self.projects_dir = app.config["TEST_PROJECTS_DIR"] = "test_projects"

        db.create_all()

        credentials = base64.b64encode("user:password".encode()).decode()
        self.headers = {"Authorization": f"Basic {credentials}"}

        self.client = app.test_client()
        self.context = app.test_request_context()
        self.context.push()

        self.setupModels()

    def setupModels(self):
        """Hook for subclasses to create per-test database objects."""

    def tearDown(self):
        """Wait for background work, then undo everything ``setUp`` did."""
        try:
            self.wait_for_bg_jobs(raise_errors=False)
        finally:
            # Always release the request context and per-test state, even if
            # waiting for background jobs failed.
            self.context.pop()

            if os.path.exists(self.projects_dir):
                shutil.rmtree(self.projects_dir)

            db.drop_all()

    def _do_request(self, request_func, *args, status_code=200, **kwargs):
        """Call *request_func* (following redirects) and check the status.

        The response body is used as the assertion message. It is decoded as
        text when possible and otherwise shown as raw bytes.

        :return: the response object
        """
        response = request_func(*args, follow_redirects=True, **kwargs)
        try:
            content = response.get_data().decode()
        except UnicodeDecodeError:
            content = response.get_data()

        self.assertEqual(response.status_code, status_code, content)
        return response

    def _auth_headers(self, extra: T.Optional[dict] = None) -> dict:
        """Return a copy of ``self.headers`` updated with *extra*, if given.

        Entries in *extra* take precedence. ``self.headers`` is never modified.
        """
        headers = dict(self.headers)
        if extra:
            headers.update(extra)
        return headers

    def post(self,
             url: str,
             *,
             status_code: int = 200,
             data: T.Optional[dict] = None,
             json: T.Optional[dict] = None,
             **kwargs):
        """Send an authenticated POST to *url* and assert *status_code*.

        A ``headers`` keyword is merged with the authentication headers.
        Any other keyword is forwarded to the test client.
        """
        headers = self._auth_headers(kwargs.pop("headers", None))
        return self._do_request(
            self.client.post,
            url,
            status_code=status_code,
            json=json,
            data=data,
            headers=headers,
            **kwargs
        )

    def get(self,
            url: str,
            *,
            status_code: int = 200,
            data: T.Optional[dict] = None,
            json: T.Optional[dict] = None,
            **kwargs):
        """Send an authenticated GET to *url* and assert *status_code*.

        A ``headers`` keyword is merged with the authentication headers.
        Any other keyword (e.g. ``query_string``) is forwarded to the test
        client.
        """
        headers = self._auth_headers(kwargs.pop("headers", None))
        return self._do_request(
            self.client.get,
            url,
            status_code=status_code,
            json=json,
            data=data,
            headers=headers,
            **kwargs
        )