# base.py (4.7 KB)
  1. import eventlet
  2. import logging
  3. import os
  4. import shutil
  5. import typing as T
  6. import unittest
  7. import base64
  8. from flask_htpasswd import HtPasswdAuth
  9. from pathlib import Path
  10. from unittest import mock
  11. from pycs import app
  12. from pycs import db
  13. from pycs import settings
  14. from pycs.database.LabelProvider import LabelProvider
  15. from pycs.database.Model import Model
  16. from pycs.frontend.WebServer import WebServer
  17. from pycs.util.PipelineCache import PipelineCache
  18. server = None
  19. def pаtch_tpool_execute(test_func):
  20. def call_func(func, *args, **kwargs):
  21. return func(*args, **kwargs)
  22. decorator = mock.patch("eventlet.tpool.execute",
  23. side_effect=call_func)
  24. return decorator(test_func)
  25. class BaseTestCase(unittest.TestCase):
  26. _sleep_time = 0.2
  27. server = None
  28. DB_FILE = Path.cwd() / "test.sqlite"
  29. @classmethod
  30. def setUpClass(cls, discovery: bool = False):
  31. global server
  32. app.logger.setLevel(logging.CRITICAL)
  33. app.config["TESTING"] = True
  34. app.config["WTF_CSRF_ENABLED"] = False
  35. app.config["DEBUG"] = False
  36. app.config["SQLALCHEMY_DATABASE_URI"] = f"sqlite:///{cls.DB_FILE}"
  37. # Set dummy password protection.
  38. # This allows authentication with the credentials user:password.
  39. app.config['FLASK_HTPASSWD_PATH'] = '.test-htpasswd'
  40. if not os.path.isfile(app.config['FLASK_HTPASSWD_PATH']):
  41. with open(app.config['FLASK_HTPASSWD_PATH'], 'w') as f:
  42. f.write('user:$apr1$fmi16nrq$3C4MfxW3ChrUNjSLLTB3x.')
  43. htpasswd = HtPasswdAuth(app)
  44. if server is None:
  45. settings["pipeline_cache_time"] = 2
  46. server = WebServer(app, htpasswd, settings, discovery)
  47. if cls.server is None:
  48. cls.server = server
  49. db.create_all()
  50. # # run discovery modules manually
  51. # Model.discover("models/")
  52. # LabelProvider.discover("labels/")
  53. cls.server.pipelines.start()
  54. def wait_for_bg_jobs(self, raise_errors=True):
  55. # wait for JobRunner jobs to finish
  56. while True:
  57. ready = True
  58. for job in self.server.jobs.list():
  59. if job.finished is None:
  60. app.logger.debug(f"{job} is not finished!")
  61. ready = False
  62. break
  63. if raise_errors:
  64. self.assertTrue(job.exception is None,
  65. job.exception)
  66. if ready:
  67. break
  68. self.wait_for_coroutines()
  69. # wait for PipelineCache to finish
  70. while not self.server.pipelines.is_empty:
  71. self.wait_for_coroutines()
  72. def wait_for_coroutines(self):
  73. eventlet.sleep(self._sleep_time)
  74. @classmethod
  75. def tearDownClass(cls):
  76. super().tearDownClass()
  77. cls.server.pipelines.shutdown()
  78. while not cls.server.pipelines.is_empty:
  79. eventlet.sleep(cls._sleep_time)
  80. def setUp(self):
  81. self.projects_dir = app.config["TEST_PROJECTS_DIR"] = "test_projects"
  82. db.create_all()
  83. credentials = base64.b64encode("user:password".encode()).decode()
  84. self.headers = { 'Authorization' : 'Basic %s' % credentials }
  85. self.client = app.test_client()
  86. self.context = app.test_request_context()
  87. self.context.push()
  88. self.setupModels()
  89. def setupModels(self):
  90. pass
  91. def tearDown(self):
  92. self.wait_for_bg_jobs(raise_errors=False)
  93. self.context.pop()
  94. if os.path.exists(self.projects_dir):
  95. shutil.rmtree(self.projects_dir)
  96. db.drop_all()
  97. def _do_request(self, request_func, *args, status_code=200, **kwargs):
  98. response = request_func(*args, follow_redirects=True, **kwargs)
  99. try:
  100. content = response.get_data().decode()
  101. except UnicodeDecodeError:
  102. content = response.get_data()
  103. self.assertEqual(response.status_code, status_code, content)
  104. return response
  105. def post(self,
  106. url: str,
  107. *,
  108. status_code: int = 200,
  109. data: T.Optional[dict] = None,
  110. json: T.Optional[dict] = None,
  111. **kwargs):
  112. return self._do_request(
  113. self.client.post,
  114. url,
  115. status_code=status_code,
  116. json=json,
  117. data=data,
  118. headers=self.headers,
  119. **kwargs
  120. )
  121. def get(self,
  122. url: str,
  123. *,
  124. status_code: int = 200,
  125. data: T.Optional[dict] = None,
  126. json: T.Optional[dict] = None):
  127. return self._do_request(
  128. self.client.get,
  129. url,
  130. status_code=status_code,
  131. json=json,
  132. data=data,
  133. headers=self.headers
  134. )