import eventlet
import logging
import os
import shutil
import typing as T
import unittest

from pathlib import Path
from unittest import mock

from pycs import app
from pycs import db
from pycs import settings
from pycs.database.LabelProvider import LabelProvider
from pycs.database.Model import Model
from pycs.frontend.WebServer import WebServer
from pycs.util.PipelineCache import PipelineCache

# Module-level singleton: the WebServer is created once and shared by all
# test classes (see BaseTestCase.setUpClass) because starting it is expensive.
server = None
def patch_tpool_execute(test_func):
    """Decorator that makes ``eventlet.tpool.execute`` run synchronously.

    Tests decorated with this run code that would normally be off-loaded to
    eventlet's OS thread pool directly in the calling greenthread, which
    makes the execution deterministic and debuggable.

    Args:
        test_func: the test function/method to wrap.

    Returns:
        ``test_func`` wrapped so that ``eventlet.tpool.execute`` is patched
        for the duration of the test.
    """

    def call_func(func, *args, **kwargs):
        # Run the off-loaded callable inline instead of via the thread pool.
        return func(*args, **kwargs)

    decorator = mock.patch("eventlet.tpool.execute",
                           side_effect=call_func)
    return decorator(test_func)


# Backward-compatibility alias: the original name was spelled with a
# Cyrillic 'а' (U+0430) homoglyph — keep the old spelling resolvable
# for any existing callers while exposing the correctly-spelled name.
pаtch_tpool_execute = patch_tpool_execute
class BaseTestCase(unittest.TestCase):
    """Shared base class for pycs integration tests.

    Manages a module-level :class:`WebServer` singleton, a per-test Flask
    request context and database, plus helpers for waiting on background
    jobs and for issuing HTTP requests against the Flask test client.
    """

    # How long (seconds) to yield to eventlet between polls of background state.
    _sleep_time = 0.2
    # Class-level reference to the shared WebServer instance (set in setUpClass).
    server = None
    # SQLite database file used by the tests (created in the working directory).
    DB_FILE = Path.cwd() / "test.sqlite"

    @classmethod
    def setUpClass(cls, discovery: bool = False):
        """Configure the Flask app for testing and start the shared server.

        Args:
            discovery: forwarded to ``WebServer``; presumably enables the
                model/label discovery modules — confirm against WebServer.
        """
        global server

        app.logger.setLevel(logging.CRITICAL)
        app.config["TESTING"] = True
        app.config["WTF_CSRF_ENABLED"] = False
        app.config["DEBUG"] = False
        app.config["SQLALCHEMY_DATABASE_URI"] = f"sqlite:///{cls.DB_FILE}"

        # Create the WebServer only once for the whole test session; every
        # test class reuses the module-level singleton.
        if server is None:
            settings["pipeline_cache_time"] = 2
            server = WebServer(app, settings, discovery)

        if cls.server is None:
            cls.server = server

        db.create_all()

        # # run discovery modules manually
        # Model.discover("models/")
        # LabelProvider.discover("labels/")

        cls.server.pipelines.start()

    def wait_for_bg_jobs(self, raise_errors=True):
        """Block until all background jobs and cached pipelines have finished.

        Args:
            raise_errors: if True, fail the test when a finished job
                recorded an exception.
        """
        # wait for JobRunner jobs to finish
        while True:
            ready = True

            for job in self.server.jobs.list():
                if job.finished is None:
                    app.logger.debug(f"{job} is not finished!")
                    ready = False
                    break

                # Only reached for finished jobs: surface recorded errors.
                if raise_errors:
                    self.assertTrue(job.exception is None,
                                    job.exception)

            if ready:
                break

            self.wait_for_coroutines()

        # wait for PipelineCache to finish
        while not self.server.pipelines.is_empty:
            self.wait_for_coroutines()

    def wait_for_coroutines(self):
        """Yield control so eventlet greenthreads can make progress."""
        eventlet.sleep(self._sleep_time)

    @classmethod
    def tearDownClass(cls):
        """Shut down the pipeline cache and wait until it has drained."""
        super().tearDownClass()

        cls.server.pipelines.shutdown()
        while not cls.server.pipelines.is_empty:
            eventlet.sleep(cls._sleep_time)

    def setUp(self):
        """Create a fresh database, test client and request context."""
        self.projects_dir = app.config["TEST_PROJECTS_DIR"] = "test_projects"
        db.create_all()

        self.client = app.test_client()
        self.context = app.test_request_context()
        self.context.push()

        self.setupModels()

    def setupModels(self):
        """Hook for subclasses to create model fixtures; default is a no-op."""
        pass

    def tearDown(self):
        """Wait for background work, then pop context and remove test data."""
        self.wait_for_bg_jobs(raise_errors=False)
        self.context.pop()

        if os.path.exists(self.projects_dir):
            shutil.rmtree(self.projects_dir)

        db.drop_all()

    def _do_request(self, request_func, *args, status_code=200, **kwargs):
        """Issue a request and assert the expected status code.

        Args:
            request_func: a test-client method (e.g. ``self.client.get``).
            status_code: the HTTP status the response must have.

        Returns:
            The response object. On failure, the response body (decoded if
            possible) is used as the assertion message.
        """
        response = request_func(*args, follow_redirects=True, **kwargs)
        try:
            content = response.get_data().decode()
        except UnicodeDecodeError:
            # Binary payloads cannot be decoded; report the raw bytes instead.
            content = response.get_data()
        self.assertEqual(response.status_code, status_code, content)
        return response

    def post(self,
             url: str,
             *,
             status_code: int = 200,
             data: T.Optional[dict] = None,
             json: T.Optional[dict] = None,
             **kwargs):
        """POST to *url* and assert *status_code*; returns the response."""
        return self._do_request(
            self.client.post,
            url,
            status_code=status_code,
            json=json,
            data=data,
            **kwargs
        )

    def get(self,
            url: str,
            *,
            status_code: int = 200,
            data: T.Optional[dict] = None,
            json: T.Optional[dict] = None,
            **kwargs):
        """GET *url* and assert *status_code*; returns the response.

        ``**kwargs`` is forwarded to the test client for consistency
        with :meth:`post` (backward-compatible addition).
        """
        return self._do_request(
            self.client.get,
            url,
            status_code=status_code,
            json=json,
            data=data,
            **kwargs
        )
|