123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212 |
import uuid
from pathlib import Path

from flask import url_for

from pycs.database.LabelProvider import LabelProvider
from pycs.database.Model import Model
from pycs.database.Project import Project
from tests.base import BaseTestCase
from tests.base import patch_tpool_execute
class _BasePipelineTests(BaseTestCase):
    """Shared fixture for the pipeline endpoint tests.

    Every test gets a fresh project that uses the model discovered in
    ``tests/client/test_models`` and one image file registered in that
    project, with placeholder bytes written to its path on disk.
    """

    def setupModels(self):
        """Create the test project, its folders, and a single image file."""
        super().setupModels()

        Model.discover("tests/client/test_models")
        # ``one()`` (SQLAlchemy-style query) is expected to raise unless
        # exactly one model was discovered.
        self.model = Model.query.one()

        self.project = self.new_project()

        root = Path(self.project.root_folder)
        data_root = Path(self.project.data_folder)

        for folder in [data_root, root / "temp"]:
            folder.mkdir(exist_ok=True, parents=True)

        file_uuid = str(uuid.uuid1())
        self.file, is_new = self.project.add_file(
            uuid=file_uuid,
            file_type="image",
            name="name",
            filename="image",
            extension=".jpg",
            size=32 * 1024,
        )
        self.assertTrue(is_new)

        # The file only has to exist on disk; its bytes are not a real JPEG.
        with open(self.file.absolute_path, "wb") as f:
            f.write(b"some content")

    def new_project(self, **kwargs):
        """Create the test project via ``Project.new``.

        Extra keyword arguments (e.g. ``label_provider``) are forwarded
        unchanged to ``Project.new``.
        """
        return Project.new(
            name="test_project",
            description="Project for a test case",
            model=self.model,
            root_folder="project_folder",
            external_data=False,
            data_folder="project_folder/data",
            **kwargs
        )

    def tearDown(self):
        """Drain background jobs and delete the project, then base cleanup."""
        try:
            # Jobs left running (e.g. by the "busy" tests) are drained without
            # raising, so their errors do not mask the test's own outcome.
            self.wait_for_bg_jobs(raise_errors=False)
            self.project.delete()
        finally:
            # Always run the base cleanup, even if deleting the project fails.
            super().tearDown()
class ModelPipelineTests(_BasePipelineTests):
    """Tests for the model prediction and training endpoints."""

    # NOTE(review): presumably read by the base test case / test model so that
    # background jobs run long enough for the "busy" checks below — confirm.
    _sleep_time: float = .2

    def test_predict_file_busy(self):
        """A second prediction request for the same file, sent without waiting
        for the first, is rejected with 400."""
        url = url_for("predict_file", file_id=self.file.id)
        self.post(url, json=dict(predict=True))
        self.post(url, json=dict(predict=True), status_code=400)

    def test_predict_file_errors(self):
        """Unknown file -> 404; missing or falsy ``predict`` payload -> 400."""
        self.post(url_for("predict_file", file_id=4242),
                  status_code=404)

        url = url_for("predict_file", file_id=self.file.id)
        for data in [None, dict(), dict(predict=False)]:
            self.post(url, status_code=400, json=data)

    def test_predict_file(self):
        """Predicting a file stores exactly one result for it."""
        url = url_for("predict_file", file_id=self.file.id)

        self.assertEqual(0, self.file.results.count())
        self.post(url, json=dict(predict=True))
        self.wait_for_bg_jobs()
        self.assertEqual(1, self.file.results.count())

    def test_predict_file_multiple_times(self):
        """Re-predicting a file keeps a single result instead of adding more."""
        url = url_for("predict_file", file_id=self.file.id)

        self.assertEqual(0, self.file.results.count())
        for _ in range(2):
            self.post(url, json=dict(predict=True))
            self.wait_for_bg_jobs()
            self.assertEqual(1, self.file.results.count())

    def test_predict_model_errors(self):
        """Unknown project -> 404; ``predict`` other than "new"/"all" -> 400."""
        self.post(url_for("predict_model", project_id=4242),
                  status_code=404)

        url = url_for("predict_model", project_id=self.project.id)
        invalid_payloads = [
            None,
            dict(),
            dict(predict=False),
            dict(predict=True),
            dict(predict="not new or all"),
        ]
        for data in invalid_payloads:
            self.post(url, status_code=400, json=data)

    def test_predict_model_busy(self):
        """A second project-wide prediction request is rejected with 400."""
        url = url_for("predict_model", project_id=self.project.id)
        self.post(url, json=dict(predict="new"))
        self.post(url, json=dict(predict="new"), status_code=400)

    def test_predict_model_for_new(self):
        """Project-wide prediction for new files is accepted and completes."""
        url = url_for("predict_model", project_id=self.project.id)
        self.post(url, json=dict(predict="new"))
        # tearDown drains jobs with raise_errors=False; wait here so a failing
        # background prediction fails this test instead of passing silently.
        self.wait_for_bg_jobs()

    def test_predict_model_for_all(self):
        """Project-wide prediction for all files is accepted and completes."""
        url = url_for("predict_model", project_id=self.project.id)
        self.post(url, json=dict(predict="all"))
        # Surface background-job errors that tearDown would otherwise ignore.
        self.wait_for_bg_jobs()

    def test_model_fit_errors(self):
        """Unknown project -> 404; missing or falsy ``fit`` payload -> 400."""
        self.post(url_for("fit_model", project_id=4242),
                  status_code=404)

        url = url_for("fit_model", project_id=self.project.id)
        for data in [None, dict(), dict(fit=False)]:
            self.post(url, status_code=400, json=data)

    def test_model_fit_busy(self):
        """A second fit request, sent without waiting, is rejected with 400."""
        url = url_for("fit_model", project_id=self.project.id)
        self.post(url, json=dict(fit=True))
        self.post(url, json=dict(fit=True), status_code=400)

    def test_model_fit(self):
        """Fitting the model is accepted and the background job completes."""
        url = url_for("fit_model", project_id=self.project.id)
        self.post(url, json=dict(fit=True))
        # Surface background-job errors that tearDown would otherwise ignore.
        self.wait_for_bg_jobs()
class LabelProviderPipelineTests:
    """Mixin holding the label-provider pipeline tests.

    Not a test case by itself: concrete subclasses list it before
    ``_BasePipelineTests`` in their bases and provide two properties:

    * ``label_provider`` -- the provider attached to the test project;
    * ``n_labels`` -- how many labels executing that provider creates.
    """

    def new_project(self, **kwargs):
        """Create the test project with the subclass' label provider attached.

        Keyword arguments are forwarded to the next ``new_project`` in the
        MRO, matching its ``**kwargs`` signature; an explicitly passed
        ``label_provider`` takes precedence over ``self.label_provider``.
        """
        # Run discovery first; the subclasses' ``label_provider`` properties
        # then query ``LabelProvider`` for the matching entry.
        LabelProvider.discover("tests/client/test_labels")
        kwargs.setdefault("label_provider", self.label_provider)
        return super().new_project(**kwargs)

    @property
    def url(self):
        """URL of the endpoint that executes the project's label provider."""
        return url_for("execute_label_provider", project_id=self.project.id)

    def test_label_provider_errors(self):
        """Unknown project -> 404; bad payload or no provider set -> 400."""
        url = url_for("execute_label_provider", project_id=4242)
        self.post(url, status_code=404)

        for data in [None, dict(), dict(execute=False)]:
            self.post(self.url, json=data, status_code=400)

        # A project without a label provider cannot execute one.
        self.project.label_provider = None
        self.project.commit()
        self.post(self.url, json=dict(execute=True), status_code=400)

    def test_label_provider_busy(self):
        """A second execution request, sent without waiting, is rejected."""
        self.post(self.url, json=dict(execute=True))
        self.post(self.url, json=dict(execute=True), status_code=400)

    def test_label_loading(self):
        """Executing the provider creates exactly ``n_labels`` labels."""
        self.post(self.url, json=dict(execute=True))
        self.wait_for_bg_jobs()
        self.assertEqual(self.n_labels, self.project.labels.count())

    def test_label_loading_multiple(self):
        """Re-executing the provider never duplicates the existing labels."""
        for _ in range(3):
            self.post(self.url, json=dict(execute=True))
            self.wait_for_bg_jobs()
            # Checked after every run, not just the last one.
            self.assertEqual(self.n_labels, self.project.labels.count())
class SimpleLabelProviderPipelineTests(
        LabelProviderPipelineTests, _BasePipelineTests):
    """Label-provider pipeline tests run against the "Simple" provider."""

    @property
    def n_labels(self):
        """Number of labels the simple provider is expected to create."""
        return 10

    @property
    def label_provider(self):
        """The one label provider whose name contains "Simple"."""
        matching = LabelProvider.query.filter(
            LabelProvider.name.contains("Simple"))
        return matching.one()
class HierarchicalLabelProviderPipelineTests(
        LabelProviderPipelineTests, _BasePipelineTests):
    """Label-provider pipeline tests run against the "Hierarchical" provider."""

    @property
    def n_labels(self):
        """Total labels: 10 roots, 3 children per root, 3 per child (130)."""
        roots = 10
        intermediate = roots * 3
        leaves = intermediate * 3
        return leaves + intermediate + roots

    @property
    def label_provider(self):
        """The one label provider whose name contains "Hierarchical"."""
        matching = LabelProvider.query.filter(
            LabelProvider.name.contains("Hierarchical"))
        return matching.one()
|