import eventlet
import io
import time

from test.base import BaseTestCase

from pycs.database.File import File
from pycs.database.Label import Label
from pycs.database.LabelProvider import LabelProvider
from pycs.database.Model import Model
from pycs.database.Project import Project
from pycs.database.Result import Result


class ClientTests(BaseTestCase):
    """End-to-end tests that exercise the HTTP endpoints through ``self.client``."""

    def setUp(self):
        super().setUp()

        label_provider = LabelProvider.query.filter_by(name="Fixed Label Provider v1").one()
        model = Model.query.filter_by(name="Fixed Base Model v1").one()

        # Only the ids are kept; they are what the request payloads below send.
        self.label_provider_id = label_provider.id
        self.model_id = model.id

    # ------------------------------------------------------------------
    # request helpers
    # ------------------------------------------------------------------

    def _do_request(self, request_func, *args, status_code=200, **kwargs):
        """Call *request_func* (following redirects) and assert the status code.

        The decoded response body is used as the assertion message so that a
        failing status shows what the server returned.
        """
        response = request_func(*args, follow_redirects=True, **kwargs)
        self.assertEqual(response.status_code, status_code, response.get_data().decode())

        return response

    def _post(self, url, status_code=200, content_type=None, json=None, data=None):
        """Send a POST request to *url* and assert the expected status code."""
        return self._do_request(
            self.client.post,
            url,
            status_code=status_code,
            json=json,
            data=data,
            content_type=content_type,
        )

    def _get(self, url, status_code=200, content_type=None, json=None, data=None):
        """Send a GET request to *url* and assert the expected status code."""
        return self._do_request(
            self.client.get,
            url,
            status_code=status_code,
            json=json,
            data=data,
            content_type=content_type,
        )

    # ------------------------------------------------------------------
    # fixture helpers (shared request payloads used by several tests)
    # ------------------------------------------------------------------

    def _create_project(self, label=None, name="some name", description="some description"):
        """POST a new project that uses the fixture model and return ``Project.query.first()``.

        :param label: label provider id to attach, or ``None`` for no label provider
        :param name: project name sent in the request
        :param description: project description sent in the request
        """
        self._post(
            "/projects",
            json=dict(
                name=name,
                description=description,
                model=self.model_id,
                label=label,
                external=None,
            ),
        )
        return Project.query.first()

    def _upload_file(self, project_id, content=b"some content", filename="image.jpg"):
        """Upload *content* as *filename* to the project and return ``File.query.first()``."""
        self._post(
            f"/projects/{project_id}/data",
            data=dict(file=(io.BytesIO(content), filename)),
            content_type="multipart/form-data",
        )
        return File.query.first()

    def _add_bounding_box_result(self, file_id, label=2, x=0, y=0, w=0, h=0):
        """POST a bounding-box result (x/y/w/h payload) for the given file.

        NOTE(review): the default ``label=2`` is a raw database id; it assumes the
        label provider has created a label with that id — TODO confirm.
        """
        self._post(
            f"/data/{file_id}/results",
            json=dict(
                type="bounding-box",
                data=dict(x=x, y=y, w=w, h=h),
                label=label,
            ),
        )

    # ------------------------------------------------------------------
    # project creation
    # ------------------------------------------------------------------

    def test_project_creation_without_label_provider(self):
        self.assertEqual(0, Project.query.count())
        self.assertEqual(0, Label.query.count())

        self._create_project(label=None)

        self.assertEqual(1, Project.query.count())

        project = Project.query.first()
        self.assertIsNotNone(project)
        self.assertIsNotNone(project.model)
        self.assertIsNone(project.label_provider)

        self.wait_for_coroutines()
        # Without a label provider no labels may be created.
        self.assertEqual(0, Label.query.count())

    def test_project_creation(self):
        self.assertEqual(0, Project.query.count())
        self.assertEqual(0, Label.query.count())

        self._create_project(label=self.label_provider_id)

        self.assertEqual(1, Project.query.count())

        project = Project.query.first()
        self.assertIsNotNone(project)
        self.assertIsNotNone(project.model)
        self.assertIsNotNone(project.label_provider)

        self.wait_for_coroutines()
        # Labels are expected to exist once the pending coroutines have run.
        self.assertNotEqual(0, Label.query.count())

    # ------------------------------------------------------------------
    # files and results
    # ------------------------------------------------------------------

    def test_adding_file_with_result(self):
        self._create_project(label=self.label_provider_id)
        self.assertEqual(1, Project.query.count())
        project = Project.query.first()
        self.wait_for_coroutines()

        self.assertEqual(0, File.query.count())
        file = self._upload_file(project.id)
        self.assertEqual(1, File.query.count())

        self.assertEqual(0, Result.query.count())
        self._add_bounding_box_result(file.id)
        self.assertEqual(1, Result.query.count())

    def test_cascade_after_project_removal(self):
        self.assertEqual(0, File.query.count())
        self.assertEqual(0, Result.query.count())
        self.assertEqual(0, Label.query.count())
        self.assertEqual(0, Project.query.count())

        project = self._create_project(label=self.label_provider_id)
        project_id = project.id
        self.wait_for_coroutines()

        file = self._upload_file(project_id)
        file_id = file.id
        self.wait_for_coroutines()

        self._add_bounding_box_result(file_id)

        self.assertNotEqual(0, File.query.count())
        self.assertNotEqual(0, Result.query.count())
        self.assertNotEqual(0, Label.query.count())
        self.assertNotEqual(0, Project.query.count())

        self.wait_for_coroutines()
        # NOTE(review): presumably gives background work time to settle before
        # removal; TODO confirm whether wait_for_coroutines() alone suffices.
        eventlet.sleep(3)

        self._post(
            f"/projects/{project_id}/remove",
            json=dict(remove=True),
        )

        # Removing the project must remove everything that belonged to it.
        self.assertEqual(0, Project.query.count())
        self.assertEqual(0, Label.query.count())
        self.assertEqual(0, File.query.count())
        self.assertEqual(0, Result.query.count())

    def test_result_download(self):
        project = self._create_project(label=self.label_provider_id)
        project_id = project.id
        self.wait_for_coroutines()

        file = self._upload_file(project_id)
        file_id = file.id
        self.wait_for_coroutines()

        self._add_bounding_box_result(file_id)

        self.wait_for_coroutines()
        # NOTE(review): see test_cascade_after_project_removal; TODO confirm need.
        eventlet.sleep(3)

        response = self._get(f"/projects/{project_id}/results")
        self.assertTrue(response.is_json)

        file = File.query.first()
        result = Result.query.first()

        self.assertEqual(1, len(response.json))
        returned_file = response.json[0]

        self.assertEqual(returned_file["filename"], file.filename)

        self.assertEqual(1, len(returned_file["results"]))
        returned_result = returned_file["results"][0]

        xywh_should = [result.data.get(attr) for attr in "xywh"]
        xywh_is = [returned_result.get(attr) for attr in "xywh"]
        self.assertListEqual(xywh_is, xywh_should)

        self.assertEqual(returned_result["label"]["id"], result.label.id)
        self.assertEqual(returned_result["label"]["name"], result.label.name)

    # ------------------------------------------------------------------
    # labels
    # ------------------------------------------------------------------

    def test_single_label_creation(self):
        project = self._create_project(
            label=None,
            name="some project",
            description="project description",
        )
        # Assert before dereferencing so a missing project fails cleanly.
        self.assertIsNotNone(project)
        project_id = project.id

        self.assertEqual(0, Label.query.filter(Label.project_id == project_id).count())

        self._post(
            f"/projects/{project_id}/labels",
            json=dict(name="Label 1"),
        )

        self.assertEqual(1, Label.query.filter(Label.project_id == project_id).count())

    def test_multiple_label_creation(self):
        project = self._create_project(
            label=None,
            name="some project",
            description="project description",
        )
        self.assertIsNotNone(project)
        project_id = project.id

        self.assertEqual(0, Label.query.filter(Label.project_id == project_id).count())

        for i in range(1, 11):
            self._post(
                f"/projects/{project_id}/labels",
                json=dict(name=f"Label {i}"),
            )
            # Each request must add exactly one label to this project.
            self.assertEqual(i, Label.query.filter(Label.project_id == project_id).count())

    def test_label_removal(self):
        project = self._create_project(
            label=None,
            name="some project",
            description="project description",
        )
        self.assertIsNotNone(project)
        project_id = project.id

        self.assertEqual(0, Label.query.filter(Label.project_id == project_id).count())

        for i in range(1, 11):
            self._post(
                f"/projects/{project_id}/labels",
                json=dict(name=f"Label {i}"),
            )

        self.assertEqual(10, Label.query.filter(Label.project_id == project_id).count())

        # Look the label up by project and name instead of assuming a global
        # primary key, and assert before dereferencing it.
        label = Label.query.filter_by(project_id=project_id, name="Label 5").first()
        self.assertIsNotNone(label)
        label_id = label.id

        self._post(
            f"/projects/{project_id}/labels/{label_id}/remove",
            json=dict(remove=True),
        )

        self.assertEqual(9, Label.query.filter(Label.project_id == project_id).count())
        self.assertIsNone(Label.query.filter_by(id=label_id).first())