base.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165
  1. import eventlet
  2. import logging
  3. import os
  4. import shutil
  5. import typing as T
  6. import unittest
  7. from pathlib import Path
  8. from unittest import mock
  9. from pycs import app
  10. from pycs import db
  11. from pycs import settings
  12. from pycs.database.LabelProvider import LabelProvider
  13. from pycs.database.Model import Model
  14. from pycs.frontend.WebServer import WebServer
  15. from pycs.util.PipelineCache import PipelineCache
# Module-wide WebServer shared by every BaseTestCase subclass. It starts as
# None, is created on first use in BaseTestCase.setUpClass (which declares
# ``global server``), and is reused by every later test class.
server = None
  17. def pаtch_tpool_execute(test_func):
  18. def call_func(func, *args, **kwargs):
  19. return func(*args, **kwargs)
  20. decorator = mock.patch("eventlet.tpool.execute",
  21. side_effect=call_func)
  22. return decorator(test_func)
  23. class BaseTestCase(unittest.TestCase):
  24. _sleep_time = 0.2
  25. server = None
  26. DB_FILE = Path.cwd() / "test.sqlite"
  27. @classmethod
  28. def setUpClass(cls, discovery: bool = False):
  29. global server
  30. app.logger.setLevel(logging.CRITICAL)
  31. app.config["TESTING"] = True
  32. app.config["WTF_CSRF_ENABLED"] = False
  33. app.config["DEBUG"] = False
  34. app.config["SQLALCHEMY_DATABASE_URI"] = f"sqlite:///{cls.DB_FILE}"
  35. if server is None:
  36. settings["pipeline_cache_time"] = 2
  37. server = WebServer(app, settings, discovery)
  38. if cls.server is None:
  39. cls.server = server
  40. db.create_all()
  41. # # run discovery modules manually
  42. # Model.discover("models/")
  43. # LabelProvider.discover("labels/")
  44. cls.server.pipelines.start()
  45. def wait_for_bg_jobs(self, raise_errors=True):
  46. # wait for JobRunner jobs to finish
  47. while True:
  48. ready = True
  49. for job in self.server.jobs.list():
  50. if job.finished is None:
  51. app.logger.debug(f"{job} is not finished!")
  52. ready = False
  53. break
  54. if raise_errors:
  55. self.assertTrue(job.exception is None,
  56. job.exception)
  57. if ready:
  58. break
  59. self.wait_for_coroutines()
  60. # wait for PipelineCache to finish
  61. while not self.server.pipelines.is_empty:
  62. self.wait_for_coroutines()
  63. def wait_for_coroutines(self):
  64. eventlet.sleep(self._sleep_time)
  65. @classmethod
  66. def tearDownClass(cls):
  67. super().tearDownClass()
  68. cls.server.pipelines.shutdown()
  69. while not cls.server.pipelines.is_empty:
  70. eventlet.sleep(cls._sleep_time)
  71. def setUp(self):
  72. self.projects_dir = app.config["TEST_PROJECTS_DIR"] = "test_projects"
  73. db.create_all()
  74. self.client = app.test_client()
  75. self.context = app.test_request_context()
  76. self.context.push()
  77. self.setupModels()
  78. def setupModels(self):
  79. pass
  80. def tearDown(self):
  81. self.wait_for_bg_jobs(raise_errors=False)
  82. self.context.pop()
  83. if os.path.exists(self.projects_dir):
  84. shutil.rmtree(self.projects_dir)
  85. db.drop_all()
  86. def _do_request(self, request_func, *args, status_code=200, **kwargs):
  87. response = request_func(*args, follow_redirects=True, **kwargs)
  88. try:
  89. content = response.get_data().decode()
  90. except UnicodeDecodeError:
  91. content = response.get_data()
  92. self.assertEqual(response.status_code, status_code, content)
  93. return response
  94. def post(self,
  95. url: str,
  96. *,
  97. status_code: int = 200,
  98. data: T.Optional[dict] = None,
  99. json: T.Optional[dict] = None,
  100. **kwargs):
  101. return self._do_request(
  102. self.client.post,
  103. url,
  104. status_code=status_code,
  105. json=json,
  106. data=data,
  107. **kwargs
  108. )
  109. def get(self,
  110. url: str,
  111. *,
  112. status_code: int = 200,
  113. data: T.Optional[dict] = None,
  114. json: T.Optional[dict] = None):
  115. return self._do_request(
  116. self.client.get,
  117. url,
  118. status_code=status_code,
  119. json=json,
  120. data=data,
  121. )