__init__.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
  1. import tempfile
  2. from flask import url_for
  3. from pycs.database.LabelProvider import LabelProvider
  4. from pycs.database.Model import Model
  5. from tests.base import BaseTestCase
  6. from tests.client.file_tests import *
  7. from tests.client.label_tests import *
  8. from tests.client.pipeline_tests import *
  9. from tests.client.project_tests import *
  10. from tests.client.result_tests import *
  class FolderInformationTest(BaseTestCase):
      """Tests for the ``folder_information`` endpoint."""

      def _check(self, url, folder, content_should):
          """POST ``folder`` to ``url`` and compare the JSON reply.

          :param url: endpoint URL to post to
          :param folder: folder path sent as the ``folder`` field
          :param content_should: dict the JSON response must equal
          """
          response = self.post(url, json=dict(folder=folder))
          self.assertTrue(response.is_json)
          self.assertDictEqual(content_should, response.json)

      def test_folder_information(self):
          """Check the reply for a missing body, a missing folder and a growing folder."""
          url = url_for("folder_information")

          # a request without a ``folder`` field must be rejected
          self.post(url, json=dict(), status_code=400)

          with tempfile.TemporaryDirectory() as folder:
              self._check(url, "/not_existing/folder",
                          dict(exists=False))

              for i in range(10):
                  # the folder currently holds exactly ``i`` files
                  self._check(url, folder,
                              dict(exists=True, count=i))
                  # add one file for the next iteration; the ``with`` closes
                  # the handle right away so no descriptor stays open, while
                  # ``delete=False`` keeps the file on disk
                  with tempfile.NamedTemporaryFile(dir=folder, delete=False):
                      pass
  class ListModelsAndLabelProviders(BaseTestCase):
      """Tests for the endpoints listing models and label providers."""

      def _assert_listing(self, url, expected):
          """GET ``url`` and check it lists exactly the ``expected`` objects.

          :param url: listing endpoint to query
          :param expected: dict mapping object id -> database object; the
              response must contain one entry per id, each equal to the
              object's ``serialize()`` output
          """
          response = self.get(url)
          self.assertTrue(response.is_json)

          content = response.json
          self.assertIsInstance(content, list)
          self.assertEqual(len(expected), len(content))
          # every created id must appear, and nothing else
          self.assertSetEqual(set(expected), {entry["id"] for entry in content})

          for entry in content:
              self.assertDictEqual(expected[entry["id"]].serialize(), entry)

      def test_list_models(self):
          """The model listing is empty at first and then shows every created model."""
          self.assertEqual(0, Model.query.count())
          url = url_for("list_models")
          self._assert_listing(url, {})

          models = {}
          n = 5
          for i in range(1, n + 1):
              model = Model.new(
                  commit=False,
                  name=f"TestModel{i}",
                  description=f"Model for a test case #{i}",
                  root_folder=f"models/fixed_model{i}",
              )
              model.supports = ["labeled-image"]
              # flush presumably assigns the primary key read on the next line
              model.flush()
              models[model.id] = model
          # presumably commits the session, persisting all models created above
          model.commit()

          self.assertEqual(n, Model.query.count())
          self._assert_listing(url, models)

      def test_list_label_providers(self):
          """The provider listing is empty at first and then shows every created provider."""
          self.assertEqual(0, LabelProvider.query.count())
          url = url_for("label_providers")
          self._assert_listing(url, {})

          providers = {}
          n = 5
          for i in range(1, n + 1):
              provider = LabelProvider.new(
                  commit=False,
                  name=f"Testprovider{i}",
                  description=f"LabelProvider for a test case #{i}",
                  root_folder=f"providers/fixed_provider{i}",
                  configuration_file=f"providers/fixed_provider{i}/config.json",
              )
              provider.supports = ["labeled-image"]
              # flush presumably assigns the primary key read on the next line
              provider.flush()
              providers[provider.id] = provider
          # presumably commits the session, persisting all providers created above
          provider.commit()

          self.assertEqual(n, LabelProvider.query.count())
          self._assert_listing(url, providers)