6
0

base.py 2.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. import os
  2. import shutil
  3. import unittest
  4. import typing as T
  5. from pycs import app
  6. from pycs import db
  7. from pycs import settings
  8. from pycs.frontend.WebServer import WebServer
  9. from pycs.database.Model import Model
  10. from pycs.database.LabelProvider import LabelProvider
  11. server = None
  12. class BaseTestCase(unittest.TestCase):
  13. def setUp(self, discovery: bool = False):
  14. global server
  15. app.config["TESTING"] = True
  16. self.projects_dir = app.config["TEST_PROJECTS_DIR"] = "test_projects"
  17. app.config["WTF_CSRF_ENABLED"] = False
  18. app.config["DEBUG"] = False
  19. app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///"
  20. db.create_all()
  21. self.client = app.test_client()
  22. self.context = app.test_request_context()
  23. self.context.push()
  24. # init the server once
  25. if server is None:
  26. server = WebServer(app, settings, discovery)
  27. elif discovery:
  28. # run discovery modules manually
  29. Model.discover("models/")
  30. LabelProvider.discover("labels/")
  31. self.setupModels()
  32. def setupModels(self):
  33. pass
  34. def tearDown(self):
  35. self.context.pop()
  36. if os.path.exists(self.projects_dir):
  37. shutil.rmtree(self.projects_dir)
  38. db.drop_all()
  39. def _do_request(self, request_func, *args, status_code=200, **kwargs):
  40. response = request_func(*args, follow_redirects=True, **kwargs)
  41. self.assertEqual(response.status_code, status_code, response.get_data().decode())
  42. return response
  43. def post(self,
  44. url: str,
  45. *,
  46. status_code: int = 200,
  47. content_type: T.Optional[str] = None,
  48. data: T.Optional[dict] = None,
  49. json: T.Optional[dict] = None):
  50. return self._do_request(
  51. self.client.post,
  52. url,
  53. status_code=status_code,
  54. json=json,
  55. data=data,
  56. content_type=content_type,
  57. )
  58. def get(self,
  59. url: str,
  60. *,
  61. status_code: int = 200,
  62. data: T.Optional[dict] = None,
  63. json: T.Optional[dict] = None):
  64. return self._do_request(
  65. self.client.get,
  66. url,
  67. status_code=status_code,
  68. json=json,
  69. data=data,
  70. content_type=content_type,
  71. )