import eventlet
import logging
import os
import shutil
import typing as T
import unittest
import base64

from flask_htpasswd import HtPasswdAuth
from pathlib import Path
from unittest import mock

from pycs import app
from pycs import db
from pycs import settings
from pycs.database.LabelProvider import LabelProvider
from pycs.database.Model import Model
from pycs.frontend.WebServer import WebServer
from pycs.util.PipelineCache import PipelineCache

# Module-wide WebServer: created by the first BaseTestCase.setUpClass call and
# reused by every later test class.
server = None


def patch_tpool_execute(test_func):
    """Decorate *test_func* so that ``eventlet.tpool.execute`` runs inline.

    While the decorated test runs, ``eventlet.tpool.execute(func, *args,
    **kwargs)`` is replaced by a mock whose side effect simply calls
    ``func(*args, **kwargs)`` and returns its result.

    Note: ``mock.patch`` used as a decorator passes the created mock to the
    decorated test as an extra positional argument.
    """

    def call_func(func, *args, **kwargs):
        return func(*args, **kwargs)

    decorator = mock.patch("eventlet.tpool.execute", side_effect=call_func)
    return decorator(test_func)


# Backward-compatible alias: this helper used to be spelled with a Cyrillic
# "а" (U+0430) in place of the Latin "a". Keep that identifier so existing
# importers of the old spelling continue to work.
pаtch_tpool_execute = patch_tpool_execute


class BaseTestCase(unittest.TestCase):
    """Common fixture for tests that drive the PyCS Flask app via its test client.

    * ``setUpClass`` configures ``app`` for testing (SQLite database at
      ``DB_FILE``, dummy htpasswd auth for ``user:password``), creates or
      reuses the module-wide ``WebServer`` and starts its pipeline cache.
    * ``setUp`` creates the database tables, an authenticated test client and
      a pushed test request context, then calls ``setupModels``.
    * ``tearDown`` waits for background jobs, pops the context, removes the
      test projects directory and drops all tables.
    """

    # Seconds passed to eventlet.sleep while polling for background work.
    _sleep_time = 0.2
    # WebServer used by this test class; assigned in setUpClass.
    server = None
    # SQLite file backing the test database (resolved against the CWD at import).
    DB_FILE = Path.cwd() / "test.sqlite"

    @classmethod
    def setUpClass(cls, discovery: bool = False):
        """Configure the app for testing and start the shared WebServer's pipelines.

        :param discovery: forwarded to ``WebServer`` when the module-wide
            server is created for the first time; ignored afterwards.
        """
        global server

        app.logger.setLevel(logging.CRITICAL)
        app.config["TESTING"] = True
        app.config["WTF_CSRF_ENABLED"] = False
        app.config["DEBUG"] = False
        app.config["SQLALCHEMY_DATABASE_URI"] = f"sqlite:///{cls.DB_FILE}"

        # Set dummy password protection.
        # This allows authentication with the credentials user:password.
        app.config['FLASK_HTPASSWD_PATH'] = '.test-htpasswd'
        if not os.path.isfile(app.config['FLASK_HTPASSWD_PATH']):
            with open(app.config['FLASK_HTPASSWD_PATH'], 'w', encoding='utf-8') as f:
                f.write('user:$apr1$fmi16nrq$3C4MfxW3ChrUNjSLLTB3x.')

        htpasswd = HtPasswdAuth(app)

        # Build the WebServer only once per process; later classes reuse it.
        if server is None:
            settings["pipeline_cache_time"] = 2
            server = WebServer(app, htpasswd, settings, discovery)

        if cls.server is None:
            cls.server = server

        db.create_all()

        # Discovery is left to WebServer (see `discovery`); to run it manually:
        # Model.discover("models/")
        # LabelProvider.discover("labels/")

        cls.server.pipelines.start()

    def wait_for_bg_jobs(self, raise_errors=True):
        """Block until all JobRunner jobs are finished and the PipelineCache is empty.

        :param raise_errors: if true, fail the test when a finished job
            recorded an exception.

        NOTE(review): there is no timeout; a job that never finishes makes
        this loop forever.
        """
        # wait for JobRunner jobs to finish
        while True:
            ready = True
            for job in self.server.jobs.list():
                if job.finished is None:
                    app.logger.debug(f"{job} is not finished!")
                    ready = False
                    break

                if raise_errors:
                    self.assertIsNone(job.exception, job.exception)

            if ready:
                break

            self.wait_for_coroutines()

        # wait for PipelineCache to finish
        while not self.server.pipelines.is_empty:
            self.wait_for_coroutines()

    def wait_for_coroutines(self):
        """Call ``eventlet.sleep`` for ``_sleep_time`` seconds so other green threads can run."""
        eventlet.sleep(self._sleep_time)

    @classmethod
    def tearDownClass(cls):
        """Shut down the pipeline cache and poll until it reports empty."""
        super().tearDownClass()
        cls.server.pipelines.shutdown()

        while not cls.server.pipelines.is_empty:
            eventlet.sleep(cls._sleep_time)

    def setUp(self):
        """Create tables, an authenticated test client and a pushed request context."""
        self.projects_dir = app.config["TEST_PROJECTS_DIR"] = "test_projects"

        db.create_all()

        # HTTP Basic credentials matching the htpasswd entry written in setUpClass.
        credentials = base64.b64encode("user:password".encode()).decode()
        self.headers = {
            'Authorization': f'Basic {credentials}'
        }

        self.client = app.test_client()
        self.context = app.test_request_context()
        self.context.push()

        self.setupModels()

    def setupModels(self):
        """Hook for subclasses to create fixture data; called at the end of setUp."""

    def tearDown(self):
        """Wait for background work, then remove test projects and drop all tables."""
        # Job errors are not asserted here; tests that care call
        # wait_for_bg_jobs() themselves.
        self.wait_for_bg_jobs(raise_errors=False)

        self.context.pop()

        if os.path.exists(self.projects_dir):
            shutil.rmtree(self.projects_dir)

        db.drop_all()

    def _do_request(self, request_func, *args, status_code=200, **kwargs):
        """Issue a request (following redirects) and assert its status code.

        :param request_func: a test-client method such as ``self.client.post``.
        :param status_code: expected HTTP status; the response body is used as
            the assertion message on mismatch.
        :return: the response object.
        """
        response = request_func(*args, follow_redirects=True, **kwargs)
        try:
            content = response.get_data().decode()
        except UnicodeDecodeError:
            # Binary payloads: fall back to the raw bytes for the message.
            content = response.get_data()

        self.assertEqual(response.status_code, status_code, content)
        return response

    def _request_headers(self, extra: T.Optional[dict] = None) -> dict:
        """Return a copy of the auth headers, updated with caller-supplied *extra* headers."""
        headers = dict(self.headers)
        if extra:
            headers.update(extra)
        return headers

    def post(self,
             url: str,
             *,
             status_code: int = 200,
             data: T.Optional[dict] = None,
             json: T.Optional[dict] = None,
             **kwargs):
        """POST to *url* with auth headers and assert the response status.

        A ``headers`` keyword is merged over the default auth headers instead
        of clashing with them; other keywords go to the test client.
        """
        headers = self._request_headers(kwargs.pop("headers", None))
        return self._do_request(
            self.client.post,
            url,
            status_code=status_code,
            json=json,
            data=data,
            headers=headers,
            **kwargs
        )

    def get(self,
            url: str,
            *,
            status_code: int = 200,
            data: T.Optional[dict] = None,
            json: T.Optional[dict] = None,
            **kwargs):
        """GET *url* with auth headers and assert the response status.

        Accepts the same extra keywords as ``post`` (including ``headers``).
        """
        headers = self._request_headers(kwargs.pop("headers", None))
        return self._do_request(
            self.client.get,
            url,
            status_code=status_code,
            json=json,
            data=data,
            headers=headers,
            **kwargs
        )