123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113 |
- import tempfile
- from flask import url_for
- from pycs.database.LabelProvider import LabelProvider
- from pycs.database.Model import Model
- from tests.base import BaseTestCase
- from tests.client.file_tests import *
- from tests.client.label_tests import *
- from tests.client.pipeline_tests import *
- from tests.client.project_tests import *
- from tests.client.result_tests import *
class FolderInformationTest(BaseTestCase):
    """Tests for the ``folder_information`` endpoint."""

    def _check(self, url, folder, content_should):
        """POST ``folder`` to ``url`` and assert the JSON reply equals ``content_should``.

        Args:
            url: URL of the ``folder_information`` endpoint.
            folder: folder path sent in the request body under the ``folder`` key.
            content_should: the exact dict the JSON response must equal.
        """
        response = self.post(url, json=dict(folder=folder))
        self.assertTrue(response.is_json)
        self.assertDictEqual(content_should, response.json)

    def test_folder_information(self):
        """A missing folder key, a non-existing folder and a growing folder are reported correctly."""
        url = url_for("folder_information")

        # a request body without ``folder`` is expected to be rejected with 400
        # (presumably ``self.post`` asserts ``status_code`` — see BaseTestCase)
        self.post(url, json=dict(), status_code=400)

        with tempfile.TemporaryDirectory() as folder:
            self._check(url, "/not_existing/folder", dict(exists=False))

            for i in range(10):
                # the folder currently holds exactly ``i`` files
                self._check(url, folder, dict(exists=True, count=i))

                # Add one more file for the next iteration. Close the handle
                # right away: ``delete=False`` keeps the file on disk, and an
                # unclosed handle would leak and can block removal of the
                # temporary directory on Windows.
                tempfile.NamedTemporaryFile(dir=folder, delete=False).close()
class ListModelsAndLabelProviders(BaseTestCase):
    """Tests for the ``list_models`` and ``label_providers`` listing endpoints."""

    def _assert_listing(self, url, expected):
        """GET ``url`` and assert it lists exactly the objects in ``expected``.

        Args:
            url: URL of the listing endpoint.
            expected: mapping of object id -> object. Every returned entry must
                equal ``expected[entry["id"]].serialize()``, and the number of
                entries must equal ``len(expected)`` (an empty mapping therefore
                requires an empty JSON list).
        """
        response = self.get(url)
        self.assertTrue(response.is_json)
        content = response.json

        self.assertIsInstance(content, list)
        self.assertEqual(len(expected), len(content))
        for entry in content:
            # fail with a readable assertion instead of a bare KeyError
            self.assertIn(entry["id"], expected)
            self.assertDictEqual(expected[entry["id"]].serialize(), entry)

    def test_list_models(self):
        """The endpoint lists nothing at first, then every created model."""
        self.assertEqual(0, Model.query.count())
        url = url_for("list_models")
        self._assert_listing(url, {})

        models = {}
        n = 5
        for i in range(1, n + 1):
            model = Model.new(
                commit=False,
                name=f"TestModel{i}",
                description=f"Model for a test case #{i}",
                root_folder=f"models/fixed_model{i}",
            )
            model.supports = ["labeled-image"]
            # flush without committing so ``model.id`` is populated for use as
            # the lookup key (presumably assigned by the DB on flush — confirm)
            model.flush()
            models[model.id] = model
        # a single commit after the loop; presumably persists every flushed
        # model at once (the count check below relies on this)
        model.commit()

        self.assertEqual(n, Model.query.count())
        self._assert_listing(url, models)

    def test_list_label_providers(self):
        """The endpoint lists nothing at first, then every created label provider."""
        self.assertEqual(0, LabelProvider.query.count())
        url = url_for("label_providers")
        self._assert_listing(url, {})

        providers = {}
        n = 5
        for i in range(1, n + 1):
            provider = LabelProvider.new(
                commit=False,
                name=f"Testprovider{i}",
                description=f"LabelProvider for a test case #{i}",
                root_folder=f"providers/fixed_provider{i}",
                configuration_file=f"providers/fixed_provider{i}/config.json",
            )
            provider.supports = ["labeled-image"]
            # flush without committing so ``provider.id`` is populated for use
            # as the lookup key (presumably assigned by the DB on flush — confirm)
            provider.flush()
            providers[provider.id] = provider
        # a single commit after the loop; presumably persists every flushed
        # provider at once (the count check below relies on this)
        provider.commit()

        self.assertEqual(n, LabelProvider.query.count())
        self._assert_listing(url, providers)
|