|
@@ -2,6 +2,9 @@ import tempfile
|
|
|
|
|
|
from flask import url_for
|
|
|
|
|
|
+from pycs.database.Model import Model
|
|
|
+from pycs.database.LabelProvider import LabelProvider
|
|
|
+
|
|
|
from tests.base import BaseTestCase
|
|
|
from tests.client.file_tests import *
|
|
|
from tests.client.label_tests import *
|
|
@@ -34,3 +37,76 @@ class FolderInformationTest(BaseTestCase):
|
|
|
f = tempfile.NamedTemporaryFile(dir=folder, delete=False)
|
|
|
|
|
|
|
|
|
+class ListModelsAndLabelProviders(BaseTestCase):
|
|
|
+
|
|
|
+ def test_list_models(self):
|
|
|
+ self.assertEqual(0, Model.query.count())
|
|
|
+ url = url_for("list_models")
|
|
|
+
|
|
|
+ response = self.get(url)
|
|
|
+ self.assertTrue(response.is_json)
|
|
|
+
|
|
|
+ self.assertEqual([], response.json)
|
|
|
+
|
|
|
+ models = {}
|
|
|
+ n = 5
|
|
|
+ for i, _ in enumerate(range(n), 1):
|
|
|
+
|
|
|
+ model = Model.new(
|
|
|
+ commit=False,
|
|
|
+ name=f"TestModel{i}",
|
|
|
+ description="Model for a test case #{i}",
|
|
|
+ root_folder=f"models/fixed_model{i}",
|
|
|
+ )
|
|
|
+ model.supports = ["labeled-image"]
|
|
|
+ model.flush()
|
|
|
+
|
|
|
+ models[model.id] = model
|
|
|
+ model.commit()
|
|
|
+
|
|
|
+ self.assertEqual(n, Model.query.count())
|
|
|
+ response = self.get(url)
|
|
|
+ self.assertTrue(response.is_json)
|
|
|
+ content = response.json
|
|
|
+ self.assertEqual(n, len(response.json))
|
|
|
+
|
|
|
+ for entry in content:
|
|
|
+ model = models[entry["id"]]
|
|
|
+ self.assertDictEqual(model.serialize(), entry)
|
|
|
+
|
|
|
+
|
|
|
+ def test_list_label_providers(self):
|
|
|
+ self.assertEqual(0, LabelProvider.query.count())
|
|
|
+ url = url_for("label_providers")
|
|
|
+
|
|
|
+ response = self.get(url)
|
|
|
+ self.assertTrue(response.is_json)
|
|
|
+
|
|
|
+ self.assertEqual([], response.json)
|
|
|
+
|
|
|
+ providers = {}
|
|
|
+ n = 5
|
|
|
+ for i, _ in enumerate(range(n), 1):
|
|
|
+
|
|
|
+ provider = LabelProvider.new(
|
|
|
+ commit=False,
|
|
|
+ name=f"Testprovider{i}",
|
|
|
+ description="LabelProvider for a test case #{i}",
|
|
|
+ root_folder=f"providers/fixed_provider{i}",
|
|
|
+ configuration_file=f"providers/fixed_provider{i}/config.json",
|
|
|
+ )
|
|
|
+ provider.supports = ["labeled-image"]
|
|
|
+ provider.flush()
|
|
|
+
|
|
|
+ providers[provider.id] = provider
|
|
|
+ provider.commit()
|
|
|
+
|
|
|
+ self.assertEqual(n, LabelProvider.query.count())
|
|
|
+ response = self.get(url)
|
|
|
+ self.assertTrue(response.is_json)
|
|
|
+ content = response.json
|
|
|
+ self.assertEqual(n, len(response.json))
|
|
|
+
|
|
|
+ for entry in content:
|
|
|
+ provider = providers[entry["id"]]
|
|
|
+ self.assertDictEqual(provider.serialize(), entry)
|