"""Shared test fixtures for the pycs test suite (base.py)."""
  1. import os
  2. import shutil
  3. import unittest
  4. import eventlet
  5. import typing as T
  6. from unittest import mock
  7. from pathlib import Path
  8. from pycs import app
  9. from pycs import db
  10. from pycs import settings
  11. from pycs.frontend.WebServer import WebServer
  12. from pycs.database.Model import Model
  13. from pycs.database.LabelProvider import LabelProvider
  14. from pycs.util.PipelineCache import PipelineCache
  15. server = None
  16. def pаtch_tpool_execute(test_func):
  17. def call_func(func, *args, **kwargs):
  18. return func(*args, **kwargs)
  19. decorator = mock.patch("eventlet.tpool.execute",
  20. side_effect=call_func)
  21. return decorator(test_func)
  22. class BaseTestCase(unittest.TestCase):
  23. _sleep_time = 0.2
  24. server = None
  25. DB_FILE = Path.cwd() / "test.sqlite"
  26. @classmethod
  27. def setUpClass(cls, discovery: bool = False):
  28. global server
  29. app.config["TESTING"] = True
  30. app.config["WTF_CSRF_ENABLED"] = False
  31. app.config["DEBUG"] = False
  32. app.config["SQLALCHEMY_DATABASE_URI"] = f"sqlite:///{cls.DB_FILE}"
  33. if server is None:
  34. settings["pipeline_cache_time"] = 2
  35. server = WebServer(app, settings, discovery)
  36. if cls.server is None:
  37. cls.server = server
  38. db.create_all()
  39. # # run discovery modules manually
  40. # Model.discover("models/")
  41. # LabelProvider.discover("labels/")
  42. cls.server.pipelines.start()
  43. def wait_for_bg_jobs(self, raise_errors=True):
  44. # wait for JobRunner jobs to finish
  45. while True:
  46. ready = True
  47. for job in self.server.jobs.list():
  48. if job.finished is None:
  49. print(f"{job} is not finished!")
  50. ready = False
  51. break
  52. if raise_errors:
  53. self.assertTrue(job.exception is None,
  54. job.exception)
  55. if ready:
  56. break
  57. eventlet.sleep(self._sleep_time)
  58. # wait for PipelineCache to finish
  59. while not self.server.pipelines.is_empty:
  60. eventlet.sleep(self._sleep_time)
  61. @classmethod
  62. def tearDownClass(cls):
  63. super().tearDownClass()
  64. cls.server.pipelines.shutdown()
  65. def setUp(self):
  66. self.projects_dir = app.config["TEST_PROJECTS_DIR"] = "test_projects"
  67. db.create_all()
  68. self.client = app.test_client()
  69. self.context = app.test_request_context()
  70. self.context.push()
  71. self.setupModels()
  72. def setupModels(self):
  73. pass
  74. def tearDown(self):
  75. self.wait_for_bg_jobs(raise_errors=False)
  76. self.context.pop()
  77. if os.path.exists(self.projects_dir):
  78. shutil.rmtree(self.projects_dir)
  79. db.drop_all()
  80. def _do_request(self, request_func, *args, status_code=200, **kwargs):
  81. response = request_func(*args, follow_redirects=True, **kwargs)
  82. try:
  83. content = response.get_data().decode()
  84. except UnicodeDecodeError:
  85. content = response.get_data()
  86. self.assertEqual(response.status_code, status_code, content)
  87. return response
  88. def post(self,
  89. url: str,
  90. *,
  91. status_code: int = 200,
  92. data: T.Optional[dict] = None,
  93. json: T.Optional[dict] = None,
  94. **kwargs):
  95. return self._do_request(
  96. self.client.post,
  97. url,
  98. status_code=status_code,
  99. json=json,
  100. data=data,
  101. **kwargs
  102. )
  103. def get(self,
  104. url: str,
  105. *,
  106. status_code: int = 200,
  107. data: T.Optional[dict] = None,
  108. json: T.Optional[dict] = None):
  109. return self._do_request(
  110. self.client.get,
  111. url,
  112. status_code=status_code,
  113. json=json,
  114. data=data,
  115. )
  116. def wait_for_coroutines(self):
  117. eventlet.sleep(self._sleep_time)