|
@@ -4,16 +4,14 @@ from flask import url_for
|
|
|
from pathlib import Path
|
|
|
|
|
|
from pycs.database.Model import Model
|
|
|
+from pycs.database.LabelProvider import LabelProvider
|
|
|
from pycs.database.Project import Project
|
|
|
|
|
|
from tests.base import BaseTestCase
|
|
|
from tests.base import pаtch_tpool_execute
|
|
|
|
|
|
|
|
|
-
|
|
|
-class PipelineTests(BaseTestCase):
|
|
|
-
|
|
|
- _sleep_time = .2
|
|
|
+class _BasePipelineTests(BaseTestCase):
|
|
|
|
|
|
def setupModels(self):
|
|
|
super().setupModels()
|
|
@@ -22,14 +20,8 @@ class PipelineTests(BaseTestCase):
|
|
|
|
|
|
self.model = Model.query.one()
|
|
|
|
|
|
- self.project = Project.new(
|
|
|
- name="test_project",
|
|
|
- description="Project for a test case",
|
|
|
- model=self.model,
|
|
|
- root_folder="project_folder",
|
|
|
- external_data=False,
|
|
|
- data_folder="project_folder/data",
|
|
|
- )
|
|
|
+ self.project = self.new_project()
|
|
|
+
|
|
|
root = Path(self.project.root_folder)
|
|
|
data_root = Path(self.project.data_folder)
|
|
|
|
|
@@ -50,12 +42,27 @@ class PipelineTests(BaseTestCase):
|
|
|
with open(self.file.absolute_path, "wb") as f:
|
|
|
f.write(b"some content")
|
|
|
|
|
|
+ def new_project(self, **kwargs):
|
|
|
+
|
|
|
+ return Project.new(
|
|
|
+ name="test_project",
|
|
|
+ description="Project for a test case",
|
|
|
+ model=self.model,
|
|
|
+ root_folder="project_folder",
|
|
|
+ external_data=False,
|
|
|
+ data_folder="project_folder/data",
|
|
|
+ **kwargs
|
|
|
+ )
|
|
|
|
|
|
def tearDown(self):
|
|
|
self.wait_for_bg_jobs(raise_errors=False)
|
|
|
self.project.delete()
|
|
|
super().tearDown()
|
|
|
|
|
|
+class ModelPipelineTests(_BasePipelineTests):
|
|
|
+
|
|
|
+ _sleep_time: float = .2
|
|
|
+
|
|
|
def test_predict_file_busy(self):
|
|
|
url = url_for("predict_file", file_id=self.file.id)
|
|
|
|
|
@@ -136,3 +143,70 @@ class PipelineTests(BaseTestCase):
|
|
|
def test_model_fit(self):
|
|
|
url = url_for("fit_model", project_id=self.project.id)
|
|
|
self.post(url, json=dict(fit=True))
|
|
|
+
|
|
|
+
|
|
|
+class LabelProviderPipelineTests:
|
|
|
+
|
|
|
+ def new_project(self):
|
|
|
+ LabelProvider.discover("tests/client/test_labels")
|
|
|
+ return super().new_project(label_provider=self.label_provider)
|
|
|
+
|
|
|
+ @property
|
|
|
+ def url(self):
|
|
|
+ return url_for("execute_label_provider", project_id=self.project.id)
|
|
|
+
|
|
|
+ def test_label_provider_errors(self):
|
|
|
+ url = url_for("execute_label_provider", project_id=4242)
|
|
|
+ self.post(url, status_code=404)
|
|
|
+
|
|
|
+ for data in [None, dict(), dict(execute=False)]:
|
|
|
+ self.post(self.url, json=data, status_code=400)
|
|
|
+
|
|
|
+ self.project.label_provider = None
|
|
|
+ self.project.commit()
|
|
|
+ self.post(self.url, json=dict(execute=True), status_code=400)
|
|
|
+
|
|
|
+ def test_label_provider_busy(self):
|
|
|
+ self.post(self.url, json=dict(execute=True))
|
|
|
+ self.post(self.url, json=dict(execute=True), status_code=400)
|
|
|
+
|
|
|
+ def test_label_loading(self):
|
|
|
+ self.post(self.url, json=dict(execute=True))
|
|
|
+ self.wait_for_bg_jobs()
|
|
|
+
|
|
|
+ self.assertEqual(self.n_labels, self.project.labels.count())
|
|
|
+
|
|
|
+ def test_label_loading_multiple(self):
|
|
|
+
|
|
|
+ for i in range(3):
|
|
|
+ self.post(self.url, json=dict(execute=True))
|
|
|
+ self.wait_for_bg_jobs()
|
|
|
+
|
|
|
+ self.assertEqual(self.n_labels, self.project.labels.count())
|
|
|
+
|
|
|
+class SimpleLabelProviderPipelineTests(LabelProviderPipelineTests, _BasePipelineTests):
|
|
|
+
|
|
|
+ @property
|
|
|
+ def n_labels(self):
|
|
|
+ return 10
|
|
|
+
|
|
|
+ @property
|
|
|
+ def label_provider(self):
|
|
|
+ name_filter = LabelProvider.name.contains("Simple")
|
|
|
+ return LabelProvider.query.filter(name_filter).one()
|
|
|
+
|
|
|
+
|
|
|
+class HierarchicalLabelProviderPipelineTests(LabelProviderPipelineTests, _BasePipelineTests):
|
|
|
+
|
|
|
+ @property
|
|
|
+ def n_labels(self):
|
|
|
+ leafs = 10 * 3 * 3
|
|
|
+ intermediate = 10 * 3
|
|
|
+ roots = 10
|
|
|
+
|
|
|
+ return roots + intermediate + leafs
|
|
|
+
|
|
|
+ @property
|
|
|
+ def label_provider(self):
|
|
|
+ name_filter = LabelProvider.name.contains("Hierarchical")
|
|
|
+ return LabelProvider.query.filter(name_filter).one()
|