
added tests for label provider execution. fixed the unnecessary transaction nesting that resulted in DB errors

Dimitri Korsch, 3 years ago
commit 3d473d8765
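
The core of this fix is the transaction guard in ExecuteLabelProvider.py below: the label-provider callback can run while the session already has an open transaction, and calling db.session.begin() again in that state raises SQLAlchemy's "a transaction is already begun on this Session" error. A minimal standalone sketch of that guard, assuming SQLAlchemy 1.4+ (the helper name store_labels and the plain Session are illustrative; the project routes the same calls through Flask-SQLAlchemy's db.session):

from sqlalchemy import create_engine
from sqlalchemy.orm import Session

engine = create_engine("sqlite://")  # in-memory database, enough for the sketch

def store_labels(session: Session, work) -> None:
    # Reuse the caller's transaction if one is already active; calling
    # session.begin() on such a session raises InvalidRequestError, which
    # is the nesting error this commit removes.
    if session.in_transaction():
        work(session)
    else:
        with session.begin():  # open and commit our own transaction
            work(session)

with Session(engine) as session:
    store_labels(session, lambda s: None)  # usage: no transaction open yet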

+ 1 - 1
Makefile

@@ -12,7 +12,7 @@ run_tests:
 
 run_coverage:
 	@PYTHONWARNINGS="ignore::ResourceWarning" coverage run -m unittest discover tests/
-	@coverage run --concurrency=eventlet --append -m unittest tests.PipelineTests
+	@coverage run --concurrency=eventlet --append -m unittest tests.client.pipeline_tests
 	coverage html
 	coverage report -m
 

+ 13 - 6
pycs/frontend/endpoints/projects/ExecuteLabelProvider.py

@@ -14,7 +14,7 @@ from pycs.jobs.JobRunner import JobRunner
 
 
 class ExecuteLabelProvider(View):
-    """db
+    """
     execute the label provider associated with a passed project identifier
     """
     # pylint: disable=arguments-differ
@@ -26,15 +26,14 @@ class ExecuteLabelProvider(View):
         self.jobs = jobs
 
     def dispatch_request(self, project_id: int):
+        project = Project.get_or_404(project_id)
+
         # extract request data
         data = request.get_json(force=True)
 
         if not data.get('execute', False):
             abort(400, "execute flag is missing")
 
-        # find project
-        project = Project.get_or_404(project_id)
-
         # get label provider
         label_provider = project.label_provider
         if label_provider is None:
@@ -71,13 +70,21 @@ class ExecuteLabelProvider(View):
         project_id = project.id
         # result adds the received labels to the database and fires events
         def result(provided_labels):
-            with db.session.begin():
+            nonlocal project_id
+
+            def _inner(labels, project_id):
                 project = Project.query.get(project_id)
-                labels = project.bulk_create_labels(provided_labels)
+                labels = project.bulk_create_labels(labels)
 
                 for label in labels:
                     nm.create_label(label)
 
+            if db.session().in_transaction():
+                _inner(provided_labels, project_id)
+            else:
+                with db.session.begin():
+                    _inner(provided_labels, project_id)
+
         # run job with given functions
         jobs.run(project,
                  'Label Provider',

+ 16 - 6
pycs/util/PipelineCache.py

@@ -1,10 +1,13 @@
+import queue
 import warnings
+import eventlet
 
-from queue import Queue
 from threading import Lock
-from time import time, sleep
+from time import sleep
+from time import time
 
-from eventlet import tpool, spawn_n
+from eventlet import spawn_n
+from eventlet import tpool
 
 from pycs import app
 from pycs.database.Project import Project
@@ -25,7 +28,7 @@ class PipelineCache:
 
         self.__pipelines = {}
         self.__is_running = False
-        self.__queue = Queue()
+        self.__queue = queue.Queue()
         self.__lock = Lock()
 
         self._cache_time = cache_time or self.CLOSE_TIMER
@@ -43,7 +46,7 @@ class PipelineCache:
     @property
     def is_empty(self):
         """ checks whether the pipeline cache is empty """
-        return len(self.__pipelines) == 0
+        return len(self.__pipelines) == 0 and self.__queue.empty()
 
     def shutdown(self):
         """ puts None in the queue to signal the worker to stop """
@@ -103,7 +106,14 @@ class PipelineCache:
     def __get(self):
         while True:
             # get element from queue
-            entry = self.__queue.get()
+            while True:
+                try:
+                    entry = self.__queue.get(block=False)
+                    break
+                except queue.Empty:
+                    eventlet.sleep(0.2)
+                    pass
+
             if entry is None:
                 # closing pipeline cache
                 return None
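
The __get() change above replaces a blocking Queue.get() with a polling loop so the eventlet hub can keep scheduling other greenlets between attempts. A minimal sketch of that pattern in isolation, assuming eventlet is installed (the helper name poll_queue is illustrative, not part of the project):

import queue

import eventlet

def poll_queue(q: queue.Queue, interval: float = 0.2):
    # A blocking q.get() would park the whole OS thread and starve other
    # greenlets; polling with block=False and eventlet.sleep() yields
    # control back to the hub between attempts.
    while True:
        try:
            return q.get(block=False)
        except queue.Empty:
            eventlet.sleep(interval)

q = queue.Queue()
q.put("job")
print(poll_queue(q))  # -> job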

+ 10 - 5
tests/base.py

@@ -1,9 +1,9 @@
 import eventlet
+import logging
 import os
 import shutil
 import typing as T
 import unittest
-import logging
 
 from pathlib import Path
 from unittest import mock
@@ -34,6 +34,7 @@ class BaseTestCase(unittest.TestCase):
 
     DB_FILE = Path.cwd() / "test.sqlite"
 
+
     @classmethod
     def setUpClass(cls, discovery: bool = False):
         global server
@@ -75,18 +76,24 @@ class BaseTestCase(unittest.TestCase):
             if ready:
                 break
 
-            eventlet.sleep(self._sleep_time)
+            self.wait_for_coroutines()
 
         # wait for PipelineCache to finish
 
         while not self.server.pipelines.is_empty:
-            eventlet.sleep(self._sleep_time)
+            self.wait_for_coroutines()
+
+    def wait_for_coroutines(self):
+        eventlet.sleep(self._sleep_time)
 
     @classmethod
     def tearDownClass(cls):
         super().tearDownClass()
         cls.server.pipelines.shutdown()
 
+        while not cls.server.pipelines.is_empty:
+            eventlet.sleep(cls._sleep_time)
+
 
 
     def setUp(self):
@@ -156,5 +163,3 @@ class BaseTestCase(unittest.TestCase):
             data=data,
         )
 
-    def wait_for_coroutines(self):
-        eventlet.sleep(self._sleep_time)

+ 86 - 12
tests/client/pipeline_tests.py

@@ -4,16 +4,14 @@ from flask import url_for
 from pathlib import Path
 
 from pycs.database.Model import Model
+from pycs.database.LabelProvider import LabelProvider
 from pycs.database.Project import Project
 
 from tests.base import BaseTestCase
 from tests.base import patch_tpool_execute
 
 
-
-class PipelineTests(BaseTestCase):
-
-    _sleep_time = .2
+class _BasePipelineTests(BaseTestCase):
 
     def setupModels(self):
         super().setupModels()
@@ -22,14 +20,8 @@ class PipelineTests(BaseTestCase):
 
         self.model = Model.query.one()
 
-        self.project = Project.new(
-            name="test_project",
-            description="Project for a test case",
-            model=self.model,
-            root_folder="project_folder",
-            external_data=False,
-            data_folder="project_folder/data",
-        )
+        self.project = self.new_project()
+
         root = Path(self.project.root_folder)
         data_root = Path(self.project.data_folder)
 
@@ -50,12 +42,27 @@ class PipelineTests(BaseTestCase):
         with open(self.file.absolute_path, "wb") as f:
             f.write(b"some content")
 
+    def new_project(self, **kwargs):
+
+        return Project.new(
+            name="test_project",
+            description="Project for a test case",
+            model=self.model,
+            root_folder="project_folder",
+            external_data=False,
+            data_folder="project_folder/data",
+            **kwargs
+        )
 
     def tearDown(self):
         self.wait_for_bg_jobs(raise_errors=False)
         self.project.delete()
         super().tearDown()
 
+class ModelPipelineTests(_BasePipelineTests):
+
+    _sleep_time: float = .2
+
     def test_predict_file_busy(self):
         url = url_for("predict_file", file_id=self.file.id)
 
@@ -136,3 +143,70 @@ class PipelineTests(BaseTestCase):
     def test_model_fit(self):
         url = url_for("fit_model", project_id=self.project.id)
         self.post(url, json=dict(fit=True))
+
+
+class LabelProviderPipelineTests:
+
+    def new_project(self):
+        LabelProvider.discover("tests/client/test_labels")
+        return super().new_project(label_provider=self.label_provider)
+
+    @property
+    def url(self):
+        return url_for("execute_label_provider", project_id=self.project.id)
+
+    def test_label_provider_errors(self):
+        url = url_for("execute_label_provider", project_id=4242)
+        self.post(url, status_code=404)
+
+        for data in [None, dict(), dict(execute=False)]:
+            self.post(self.url, json=data, status_code=400)
+
+        self.project.label_provider = None
+        self.project.commit()
+        self.post(self.url, json=dict(execute=True), status_code=400)
+
+    def test_label_provider_busy(self):
+        self.post(self.url, json=dict(execute=True))
+        self.post(self.url, json=dict(execute=True), status_code=400)
+
+    def test_label_loading(self):
+        self.post(self.url, json=dict(execute=True))
+        self.wait_for_bg_jobs()
+
+        self.assertEqual(self.n_labels, self.project.labels.count())
+
+    def test_label_loading_multiple(self):
+
+        for i in range(3):
+            self.post(self.url, json=dict(execute=True))
+            self.wait_for_bg_jobs()
+
+            self.assertEqual(self.n_labels, self.project.labels.count())
+
+class SimpleLabelProviderPipelineTests(LabelProviderPipelineTests, _BasePipelineTests):
+
+    @property
+    def n_labels(self):
+        return 10
+
+    @property
+    def label_provider(self):
+        name_filter = LabelProvider.name.contains("Simple")
+        return LabelProvider.query.filter(name_filter).one()
+
+
+class HierarchicalLabelProviderPipelineTests(LabelProviderPipelineTests, _BasePipelineTests):
+
+    @property
+    def n_labels(self):
+        leafs = 10 * 3 * 3
+        intermediate = 10 * 3
+        roots = 10
+
+        return roots + intermediate + leafs
+
+    @property
+    def label_provider(self):
+        name_filter = LabelProvider.name.contains("Hierarchical")
+        return LabelProvider.query.filter(name_filter).one()

+ 9 - 0
tests/client/test_labels/simple_labels/configuration1.json

@@ -0,0 +1,9 @@
+{
+  "name": "Simple Labels",
+  "description": "Simple labels for tests",
+  "code": {
+    "module": "provider",
+    "class": "LabelProvider"
+  },
+  "hierarchical": false
+}

+ 9 - 0
tests/client/test_labels/simple_labels/configuration2.json

@@ -0,0 +1,9 @@
+{
+  "name": "Hierarchical Labels",
+  "description": "Hierarchical labels for tests",
+  "code": {
+    "module": "provider",
+    "class": "LabelProvider"
+  },
+  "hierarchical": true
+}

+ 51 - 0
tests/client/test_labels/simple_labels/provider.py

@@ -0,0 +1,51 @@
+import typing as T
+
+from pycs.interfaces.LabelProvider import LabelProvider as BaseProvider
+
+class LabelProvider(BaseProvider):
+
+    def __init__(self, root_folder: str, configuration: dict):
+        self.hierarchical = configuration["hierarchical"]
+
+    def close(self):
+        pass
+
+    def get_labels(self) -> T.List[dict]:
+
+        labels = []
+        for i, _ in enumerate(range(10), 1):
+
+            label = self.create_label(
+                reference=f"label{i}",
+                name=f"Label{i}",
+                hierarchy_level="root"
+            )
+            labels.append(label)
+
+            if not self.hierarchical:
+                continue
+
+            parent = label
+            for j, _ in enumerate(range(3), 1):
+
+                label = self.create_label(
+                    reference=f"label{i}_{j}",
+                    name=f"Label{i}_{j}",
+                    parent_reference=parent["reference"],
+                    hierarchy_level="intermediate",
+                )
+
+                labels.append(label)
+
+                second_parent = label
+                for k, _ in enumerate(range(3), 1):
+                    label = self.create_label(
+                        reference=f"label{i}_{j}_{k}",
+                        name=f"Label{i}_{j}_{k}",
+                        parent_reference=second_parent["reference"],
+                        hierarchy_level="leaf",
+                    )
+
+                    labels.append(label)
+
+        return labels