base.py 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106
  1. import os
  2. import shutil
  3. import unittest
  4. import eventlet
  5. import typing as T
  6. from unittest import mock
  7. from pycs import app
  8. from pycs import db
  9. from pycs import settings
  10. from pycs.frontend.WebServer import WebServer
  11. from pycs.database.Model import Model
  12. from pycs.database.LabelProvider import LabelProvider
  13. server = None
  14. def pаtch_tpool_execute(test_func):
  15. def call_func(func, *args, **kwargs):
  16. return func(*args, **kwargs)
  17. decorator = mock.patch("eventlet.tpool.execute",
  18. side_effect=call_func)
  19. return decorator(test_func)
  20. class BaseTestCase(unittest.TestCase):
  21. _sleep_time = 0.2
  22. def setUp(self, discovery: bool = False):
  23. global server
  24. app.config["TESTING"] = True
  25. self.projects_dir = app.config["TEST_PROJECTS_DIR"] = "test_projects"
  26. app.config["WTF_CSRF_ENABLED"] = False
  27. app.config["DEBUG"] = False
  28. app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///"
  29. db.create_all()
  30. self.client = app.test_client()
  31. self.context = app.test_request_context()
  32. self.context.push()
  33. # init the server once
  34. if server is None:
  35. server = WebServer(app, settings, discovery)
  36. elif discovery:
  37. # run discovery modules manually
  38. Model.discover("models/")
  39. LabelProvider.discover("labels/")
  40. self.setupModels()
  41. def setupModels(self):
  42. pass
  43. def tearDown(self):
  44. self.context.pop()
  45. if os.path.exists(self.projects_dir):
  46. shutil.rmtree(self.projects_dir)
  47. db.drop_all()
  48. def _do_request(self, request_func, *args, status_code=200, **kwargs):
  49. response = request_func(*args, follow_redirects=True, **kwargs)
  50. self.assertEqual(response.status_code, status_code, response.get_data().decode())
  51. return response
  52. def post(self,
  53. url: str,
  54. *,
  55. status_code: int = 200,
  56. data: T.Optional[dict] = None,
  57. json: T.Optional[dict] = None,
  58. **kwargs):
  59. return self._do_request(
  60. self.client.post,
  61. url,
  62. status_code=status_code,
  63. json=json,
  64. data=data,
  65. **kwargs
  66. )
  67. def get(self,
  68. url: str,
  69. *,
  70. status_code: int = 200,
  71. data: T.Optional[dict] = None,
  72. json: T.Optional[dict] = None):
  73. return self._do_request(
  74. self.client.get,
  75. url,
  76. status_code=status_code,
  77. json=json,
  78. data=data,
  79. )
  80. def wait_for_coroutines(self):
  81. eventlet.sleep(self._sleep_time)