# base.py
  1. import os
  2. import shutil
  3. import unittest
  4. import eventlet
  5. import typing as T
  6. from unittest import mock
  7. from pycs import app
  8. from pycs import db
  9. from pycs import settings
  10. from pycs.frontend.WebServer import WebServer
  11. from pycs.database.Model import Model
  12. from pycs.database.LabelProvider import LabelProvider
  13. from pycs.util.PipelineCache import PipelineCache
  14. server = None
  15. def pаtch_tpool_execute(test_func):
  16. def call_func(func, *args, **kwargs):
  17. return func(*args, **kwargs)
  18. decorator = mock.patch("eventlet.tpool.execute",
  19. side_effect=call_func)
  20. return decorator(test_func)
  21. class BaseTestCase(unittest.TestCase):
  22. _sleep_time = 0.2
  23. server = None
  24. @classmethod
  25. def setUpClass(cls, discovery: bool = False):
  26. global server
  27. PipelineCache.CLOSE_TIMER = 2
  28. app.config["TESTING"] = True
  29. app.config["WTF_CSRF_ENABLED"] = False
  30. app.config["DEBUG"] = False
  31. app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///"
  32. if server is None:
  33. server = WebServer(app, settings, discovery)
  34. if cls.server is None:
  35. cls.server = server
  36. db.create_all()
  37. # # run discovery modules manually
  38. # Model.discover("models/")
  39. # LabelProvider.discover("labels/")
  40. cls.server.pipelines.start()
  41. def wait_for_bg_jobs(self):
  42. # wait for JobRunner jobs to finish
  43. while True:
  44. ready = True
  45. for job in self.server.jobs.list():
  46. if job.finished is None:
  47. print(f"{job} is not finished!")
  48. ready = False
  49. break
  50. if ready:
  51. break
  52. eventlet.sleep(self._sleep_time)
  53. # wait for PipelineCache to finish
  54. while not self.server.pipelines.is_empty:
  55. eventlet.sleep(self._sleep_time)
  56. @classmethod
  57. def tearDownClass(cls):
  58. super().tearDownClass()
  59. cls.server.pipelines.shutdown()
  60. def setUp(self):
  61. self.projects_dir = app.config["TEST_PROJECTS_DIR"] = "test_projects"
  62. db.create_all()
  63. self.client = app.test_client()
  64. self.context = app.test_request_context()
  65. self.context.push()
  66. self.setupModels()
  67. def setupModels(self):
  68. pass
  69. def tearDown(self):
  70. self.wait_for_bg_jobs()
  71. self.context.pop()
  72. if os.path.exists(self.projects_dir):
  73. shutil.rmtree(self.projects_dir)
  74. db.drop_all()
  75. def _do_request(self, request_func, *args, status_code=200, **kwargs):
  76. response = request_func(*args, follow_redirects=True, **kwargs)
  77. try:
  78. content = response.get_data().decode()
  79. except UnicodeDecodeError:
  80. content = response.get_data()
  81. self.assertEqual(response.status_code, status_code, content)
  82. return response
  83. def post(self,
  84. url: str,
  85. *,
  86. status_code: int = 200,
  87. data: T.Optional[dict] = None,
  88. json: T.Optional[dict] = None,
  89. **kwargs):
  90. return self._do_request(
  91. self.client.post,
  92. url,
  93. status_code=status_code,
  94. json=json,
  95. data=data,
  96. **kwargs
  97. )
  98. def get(self,
  99. url: str,
  100. *,
  101. status_code: int = 200,
  102. data: T.Optional[dict] = None,
  103. json: T.Optional[dict] = None):
  104. return self._do_request(
  105. self.client.get,
  106. url,
  107. status_code=status_code,
  108. json=json,
  109. data=data,
  110. )
  111. def wait_for_coroutines(self):
  112. eventlet.sleep(self._sleep_time)