6
0

__init__.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109
  1. import tempfile
  2. from flask import url_for
  3. from pycs.database.LabelProvider import LabelProvider
  4. from pycs.database.Model import Model
  5. from tests.base import BaseTestCase
  6. from tests.client.file_tests import *
  7. from tests.client.label_tests import *
  8. from tests.client.pipeline_tests import *
  9. from tests.client.project_tests import *
  10. from tests.client.result_tests import *
  11. class FolderInformationTest(BaseTestCase):
  12. def _check(self, url, folder, content_should):
  13. response = self.post(url, json=dict(folder=folder))
  14. self.assertTrue(response.is_json)
  15. self.assertDictEqual(content_should, response.json)
  16. def test_folder_information(self):
  17. url = url_for("folder_information")
  18. self.post(url, json=dict(), status_code=400)
  19. with tempfile.TemporaryDirectory() as folder:
  20. self._check(url, "/not_existing/folder", dict(exists=False))
  21. for i in range(10):
  22. self._check(url, folder, dict(exists=True, count=i))
  23. tempfile.NamedTemporaryFile(dir=folder, delete=False, suffix=".jpg").close()
  24. class ListModelsAndLabelProviders(BaseTestCase):
  25. def test_list_models(self):
  26. self.assertEqual(0, Model.query.count())
  27. url = url_for("list_models")
  28. response = self.get(url)
  29. self.assertTrue(response.is_json)
  30. self.assertEqual([], response.json)
  31. models = {}
  32. n = 5
  33. for i, _ in enumerate(range(n), 1):
  34. model = Model.new(
  35. commit=False,
  36. name=f"TestModel{i}",
  37. description="Model for a test case #{i}",
  38. root_folder=f"models/fixed_model{i}",
  39. )
  40. model.supports = ["labeled-image"]
  41. model.flush()
  42. models[model.id] = model
  43. model.commit()
  44. self.assertEqual(n, Model.query.count())
  45. response = self.get(url)
  46. self.assertTrue(response.is_json)
  47. content = response.json
  48. self.assertEqual(n, len(response.json))
  49. for entry in content:
  50. model = models[entry["id"]]
  51. self.assertDictEqual(model.serialize(), entry)
  52. def test_list_label_providers(self):
  53. self.assertEqual(0, LabelProvider.query.count())
  54. url = url_for("label_providers")
  55. response = self.get(url)
  56. self.assertTrue(response.is_json)
  57. self.assertEqual([], response.json)
  58. providers = {}
  59. n = 5
  60. for i, _ in enumerate(range(n), 1):
  61. provider = LabelProvider.new(
  62. commit=False,
  63. name=f"Testprovider{i}",
  64. description="LabelProvider for a test case #{i}",
  65. root_folder=f"providers/fixed_provider{i}",
  66. configuration_file=f"providers/fixed_provider{i}/config.json",
  67. )
  68. provider.supports = ["labeled-image"]
  69. provider.flush()
  70. providers[provider.id] = provider
  71. provider.commit()
  72. self.assertEqual(n, LabelProvider.query.count())
  73. response = self.get(url)
  74. self.assertTrue(response.is_json)
  75. content = response.json
  76. self.assertEqual(n, len(response.json))
  77. for entry in content:
  78. provider = providers[entry["id"]]
  79. self.assertDictEqual(provider.serialize(), entry)