Browse Source

added tests for listing models and label providers

Dimitri Korsch 3 years ago
parent
commit
ea9f7f3b26
1 changed files with 76 additions and 0 deletions
  1. 76 0
      tests/client/__init__.py

+ 76 - 0
tests/client/__init__.py

@@ -2,6 +2,9 @@ import tempfile
 
 from flask import url_for
 
+from pycs.database.Model import Model
+from pycs.database.LabelProvider import LabelProvider
+
 from tests.base import BaseTestCase
 from tests.client.file_tests import *
 from tests.client.label_tests import *
@@ -34,3 +37,76 @@ class FolderInformationTest(BaseTestCase):
                 f = tempfile.NamedTemporaryFile(dir=folder, delete=False)
 
 
+class ListModelsAndLabelProviders(BaseTestCase):
+
+    def test_list_models(self):
+        self.assertEqual(0, Model.query.count())
+        url = url_for("list_models")
+
+        response = self.get(url)
+        self.assertTrue(response.is_json)
+
+        self.assertEqual([], response.json)
+
+        models = {}
+        n = 5
+        for i, _ in enumerate(range(n), 1):
+
+            model = Model.new(
+                commit=False,
+                name=f"TestModel{i}",
+                description="Model for a test case #{i}",
+                root_folder=f"models/fixed_model{i}",
+            )
+            model.supports = ["labeled-image"]
+            model.flush()
+
+            models[model.id] = model
+        model.commit()
+
+        self.assertEqual(n, Model.query.count())
+        response = self.get(url)
+        self.assertTrue(response.is_json)
+        content = response.json
+        self.assertEqual(n, len(response.json))
+
+        for entry in content:
+            model = models[entry["id"]]
+            self.assertDictEqual(model.serialize(), entry)
+
+
+    def test_list_label_providers(self):
+        self.assertEqual(0, LabelProvider.query.count())
+        url = url_for("label_providers")
+
+        response = self.get(url)
+        self.assertTrue(response.is_json)
+
+        self.assertEqual([], response.json)
+
+        providers = {}
+        n = 5
+        for i, _ in enumerate(range(n), 1):
+
+            provider = LabelProvider.new(
+                commit=False,
+                name=f"Testprovider{i}",
+                description="LabelProvider for a test case #{i}",
+                root_folder=f"providers/fixed_provider{i}",
+                configuration_file=f"providers/fixed_provider{i}/config.json",
+            )
+            provider.supports = ["labeled-image"]
+            provider.flush()
+
+            providers[provider.id] = provider
+        provider.commit()
+
+        self.assertEqual(n, LabelProvider.query.count())
+        response = self.get(url)
+        self.assertTrue(response.is_json)
+        content = response.json
+        self.assertEqual(n, len(response.json))
+
+        for entry in content:
+            provider = providers[entry["id"]]
+            self.assertDictEqual(provider.serialize(), entry)