base.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160
  1. import eventlet
  2. import os
  3. import shutil
  4. import typing as T
  5. import unittest
  6. import logging
  7. from pathlib import Path
  8. from unittest import mock
  9. from pycs import app
  10. from pycs import db
  11. from pycs import settings
  12. from pycs.database.LabelProvider import LabelProvider
  13. from pycs.database.Model import Model
  14. from pycs.frontend.WebServer import WebServer
  15. from pycs.util.PipelineCache import PipelineCache
  16. server = None
  17. def pаtch_tpool_execute(test_func):
  18. def call_func(func, *args, **kwargs):
  19. return func(*args, **kwargs)
  20. decorator = mock.patch("eventlet.tpool.execute",
  21. side_effect=call_func)
  22. return decorator(test_func)
  23. class BaseTestCase(unittest.TestCase):
  24. _sleep_time = 0.2
  25. server = None
  26. DB_FILE = Path.cwd() / "test.sqlite"
  27. @classmethod
  28. def setUpClass(cls, discovery: bool = False):
  29. global server
  30. app.logger.setLevel(logging.CRITICAL)
  31. app.config["TESTING"] = True
  32. app.config["WTF_CSRF_ENABLED"] = False
  33. app.config["DEBUG"] = False
  34. app.config["SQLALCHEMY_DATABASE_URI"] = f"sqlite:///{cls.DB_FILE}"
  35. if server is None:
  36. settings["pipeline_cache_time"] = 2
  37. server = WebServer(app, settings, discovery)
  38. if cls.server is None:
  39. cls.server = server
  40. db.create_all()
  41. # # run discovery modules manually
  42. # Model.discover("models/")
  43. # LabelProvider.discover("labels/")
  44. cls.server.pipelines.start()
  45. def wait_for_bg_jobs(self, raise_errors=True):
  46. # wait for JobRunner jobs to finish
  47. while True:
  48. ready = True
  49. for job in self.server.jobs.list():
  50. if job.finished is None:
  51. app.logger.debug(f"{job} is not finished!")
  52. ready = False
  53. break
  54. if raise_errors:
  55. self.assertTrue(job.exception is None,
  56. job.exception)
  57. if ready:
  58. break
  59. eventlet.sleep(self._sleep_time)
  60. # wait for PipelineCache to finish
  61. while not self.server.pipelines.is_empty:
  62. eventlet.sleep(self._sleep_time)
  63. @classmethod
  64. def tearDownClass(cls):
  65. super().tearDownClass()
  66. cls.server.pipelines.shutdown()
  67. def setUp(self):
  68. self.projects_dir = app.config["TEST_PROJECTS_DIR"] = "test_projects"
  69. db.create_all()
  70. self.client = app.test_client()
  71. self.context = app.test_request_context()
  72. self.context.push()
  73. self.setupModels()
  74. def setupModels(self):
  75. pass
  76. def tearDown(self):
  77. self.wait_for_bg_jobs(raise_errors=False)
  78. self.context.pop()
  79. if os.path.exists(self.projects_dir):
  80. shutil.rmtree(self.projects_dir)
  81. db.drop_all()
  82. def _do_request(self, request_func, *args, status_code=200, **kwargs):
  83. response = request_func(*args, follow_redirects=True, **kwargs)
  84. try:
  85. content = response.get_data().decode()
  86. except UnicodeDecodeError:
  87. content = response.get_data()
  88. self.assertEqual(response.status_code, status_code, content)
  89. return response
  90. def post(self,
  91. url: str,
  92. *,
  93. status_code: int = 200,
  94. data: T.Optional[dict] = None,
  95. json: T.Optional[dict] = None,
  96. **kwargs):
  97. return self._do_request(
  98. self.client.post,
  99. url,
  100. status_code=status_code,
  101. json=json,
  102. data=data,
  103. **kwargs
  104. )
  105. def get(self,
  106. url: str,
  107. *,
  108. status_code: int = 200,
  109. data: T.Optional[dict] = None,
  110. json: T.Optional[dict] = None):
  111. return self._do_request(
  112. self.client.get,
  113. url,
  114. status_code=status_code,
  115. json=json,
  116. data=data,
  117. )
  118. def wait_for_coroutines(self):
  119. eventlet.sleep(self._sleep_time)