import uuid
from pathlib import Path

from flask import url_for

from pycs.database.Model import Model
from pycs.database.LabelProvider import LabelProvider
from pycs.database.Project import Project
from tests.base import BaseTestCase
# NOTE: the original spelled this name with a Cyrillic "а" (U+0430), which is a
# different identifier from the ASCII one and made the import fail.
from tests.base import patch_tpool_execute


class _BasePipelineTests(BaseTestCase):
    """Shared fixture: a project using a discovered test model, plus one image file."""

    def setupModels(self):
        """Discover the test model, create the project and a single image file on disk."""
        super().setupModels()

        Model.discover("tests/client/test_models")
        self.model = Model.query.one()

        self.project = self.new_project()

        root = Path(self.project.root_folder)
        data_root = Path(self.project.data_folder)

        for folder in [data_root, root / "temp"]:
            folder.mkdir(exist_ok=True, parents=True)

        file_uuid = str(uuid.uuid1())
        self.file, is_new = self.project.add_file(
            uuid=file_uuid,
            file_type="image",
            name="name",
            filename="image",
            extension=".jpg",
            size=32 * 1024,
        )
        self.assertTrue(is_new)

        # Write real bytes so that the file's absolute_path exists on disk.
        with open(self.file.absolute_path, "wb") as f:
            f.write(b"some content")

    def new_project(self, **kwargs):
        """Create the test project bound to ``self.model``.

        Extra keyword arguments are forwarded to ``Project.new``.
        """
        return Project.new(
            name="test_project",
            description="Project for a test case",
            model=self.model,
            root_folder="project_folder",
            external_data=False,
            data_folder="project_folder/data",
            **kwargs
        )

    def tearDown(self):
        """Drain background jobs (ignoring their errors), then delete the project."""
        self.wait_for_bg_jobs(raise_errors=False)
        self.project.delete()
        super().tearDown()


class ModelPipelineTests(_BasePipelineTests):
    """Endpoint tests for per-file prediction, project prediction and model fitting."""

    # NOTE(review): not read in this file; presumably consumed by BaseTestCase — confirm.
    _sleep_time: float = .2

    def test_predict_file_busy(self):
        """A second prediction request while the first is still queued is rejected."""
        url = url_for("predict_file", file_id=self.file.id)
        self.post(url, json=dict(predict=True))
        self.post(url, json=dict(predict=True), status_code=400)

    def test_predict_file_errors(self):
        """Unknown file ids give 404; missing or falsy ``predict`` payloads give 400."""
        self.post(url_for("predict_file", file_id=4242), status_code=404)

        url = url_for("predict_file", file_id=self.file.id)
        for data in [None, dict(), dict(predict=False)]:
            self.post(url, status_code=400, json=data)

    def test_predict_file(self):
        """Predicting a single file stores exactly one result for it."""
        url = url_for("predict_file", file_id=self.file.id)

        self.assertEqual(0, self.file.results.count())
        self.post(url, json=dict(predict=True))
        self.wait_for_bg_jobs()
        self.assertEqual(1, self.file.results.count())

    def test_predict_file_multiple_times(self):
        """Re-running prediction on a file leaves the result count at one."""
        url = url_for("predict_file", file_id=self.file.id)
        self.assertEqual(0, self.file.results.count())

        self.post(url, json=dict(predict=True))
        self.wait_for_bg_jobs()
        self.assertEqual(1, self.file.results.count())

        self.post(url, json=dict(predict=True))
        self.wait_for_bg_jobs()
        self.assertEqual(1, self.file.results.count())

    def test_predict_model_errors(self):
        """Unknown projects give 404; ``predict`` must be exactly "new" or "all"."""
        self.post(url_for("predict_model", project_id=4242), status_code=404)

        url = url_for("predict_model", project_id=self.project.id)
        for data in [None, dict(), dict(predict=False), dict(predict=True),
                     dict(predict="not new or all")]:
            self.post(url, status_code=400, json=data)

    def test_predict_model_busy(self):
        """A second project prediction while the first is queued is rejected."""
        url = url_for("predict_model", project_id=self.project.id)
        self.post(url, json=dict(predict="new"))
        self.post(url, json=dict(predict="new"), status_code=400)

    def test_predict_model_for_new(self):
        """The endpoint accepts ``predict="new"``."""
        url = url_for("predict_model", project_id=self.project.id)
        self.post(url, json=dict(predict="new"))

    def test_predict_model_for_all(self):
        """The endpoint accepts ``predict="all"``."""
        url = url_for("predict_model", project_id=self.project.id)
        self.post(url, json=dict(predict="all"))

    def test_model_fit_errors(self):
        """Unknown projects give 404; missing or falsy ``fit`` payloads give 400."""
        self.post(url_for("fit_model", project_id=4242), status_code=404)

        url = url_for("fit_model", project_id=self.project.id)
        for data in [None, dict(), dict(fit=False)]:
            self.post(url, status_code=400, json=data)

    def test_model_fit_busy(self):
        """A second fit request while the first is queued is rejected."""
        url = url_for("fit_model", project_id=self.project.id)
        self.post(url, json=dict(fit=True))
        self.post(url, json=dict(fit=True), status_code=400)

    def test_model_fit(self):
        """The fit endpoint accepts ``fit=True``."""
        url = url_for("fit_model", project_id=self.project.id)
        self.post(url, json=dict(fit=True))


class LabelProviderPipelineTests:
    """Mixin with label-provider endpoint tests.

    Concrete subclasses must also inherit ``_BasePipelineTests``. They must also
    define the ``label_provider`` and ``n_labels`` properties.
    """

    def new_project(self, **kwargs):
        """Discover the test label providers and create a project using ``self.label_provider``.

        Extra keyword arguments are forwarded to the base ``new_project``. This
        matches its signature; an explicit ``label_provider`` takes precedence.
        """
        LabelProvider.discover("tests/client/test_labels")
        kwargs.setdefault("label_provider", self.label_provider)
        return super().new_project(**kwargs)

    @property
    def url(self):
        """URL that executes the label provider of this test's project."""
        return url_for("execute_label_provider", project_id=self.project.id)

    def test_label_provider_errors(self):
        """Covers unknown project (404), bad payloads (400) and a missing provider (400)."""
        url = url_for("execute_label_provider", project_id=4242)
        self.post(url, status_code=404)

        for data in [None, dict(), dict(execute=False)]:
            self.post(self.url, json=data, status_code=400)

        # With no label provider configured, even a valid request must be rejected.
        self.project.label_provider = None
        self.project.commit()
        self.post(self.url, json=dict(execute=True), status_code=400)

    def test_label_provider_busy(self):
        """A second execution while the first is queued is rejected."""
        self.post(self.url, json=dict(execute=True))
        self.post(self.url, json=dict(execute=True), status_code=400)

    def test_label_loading(self):
        """Executing the provider creates ``n_labels`` labels."""
        self.post(self.url, json=dict(execute=True))
        self.wait_for_bg_jobs()
        self.assertEqual(self.n_labels, self.project.labels.count())

    def test_label_loading_multiple(self):
        """Executing the provider repeatedly does not duplicate labels."""
        for _ in range(3):
            self.post(self.url, json=dict(execute=True))
            self.wait_for_bg_jobs()
            self.assertEqual(self.n_labels, self.project.labels.count())


class SimpleLabelProviderPipelineTests(LabelProviderPipelineTests, _BasePipelineTests):
    """Label-provider tests using the provider whose name contains "Simple"."""

    @property
    def n_labels(self):
        """Expected number of labels produced by the simple provider."""
        return 10

    @property
    def label_provider(self):
        """The single discovered provider whose name contains "Simple"."""
        name_filter = LabelProvider.name.contains("Simple")
        return LabelProvider.query.filter(name_filter).one()


class HierarchicalLabelProviderPipelineTests(LabelProviderPipelineTests, _BasePipelineTests):
    """Label-provider tests using the provider whose name contains "Hierarchical"."""

    @property
    def n_labels(self):
        """Expected label count across all hierarchy levels.

        Presumably the fixture has 10 roots, 3 children per root and 3 children
        per intermediate label. Verify this against tests/client/test_labels.
        """
        leaves = 10 * 3 * 3
        intermediate = 10 * 3
        roots = 10
        return roots + intermediate + leaves

    @property
    def label_provider(self):
        """The single discovered provider whose name contains "Hierarchical"."""
        name_filter = LabelProvider.name.contains("Hierarchical")
        return LabelProvider.query.filter(name_filter).one()