import io
import time

import eventlet

from test.base import BaseTestCase

from pycs.database.File import File
from pycs.database.Result import Result
from pycs.database.Label import Label
from pycs.database.Project import Project


class ClientTests(BaseTestCase):
    """End-to-end tests of the HTTP API, driven through ``self.client``."""

    def _do_request(self, request_func, *args, status_code=200, **kwargs):
        """Issue a request (following redirects) and assert its status code.

        The decoded response body is used as the assertion message so that
        server-side error output shows up in the test report.
        """
        response = request_func(*args, follow_redirects=True, **kwargs)
        self.assertEqual(response.status_code, status_code,
                         response.get_data().decode())
        return response

    def _post(self, url, status_code=200, content_type=None, json=None, data=None):
        """POST to ``url`` and assert the response has ``status_code``."""
        return self._do_request(self.client.post, url,
                                status_code=status_code,
                                json=json,
                                data=data,
                                content_type=content_type,
                                )

    def _get(self, url, status_code=200, content_type=None, json=None, data=None):
        """GET ``url`` and assert the response has ``status_code``."""
        return self._do_request(self.client.get, url,
                                status_code=status_code,
                                json=json,
                                data=data,
                                content_type=content_type,
                                )

    def _create_project(self):
        """Create a project through the API and return the first stored Project.

        Background work started by the creation is NOT awaited here; callers
        decide when to call ``self.wait_for_coroutines()``.
        """
        self._post("/projects", json=dict(
            name="some name",
            description="some description",
            model=1,
            label=2,
            external=None,
        ))
        return Project.query.first()

    def _upload_file(self, project_id):
        """Upload a small dummy ``image.jpg`` to the project; return the first stored File."""
        self._post(f"/projects/{project_id}/data",
                   data=dict(file=(io.BytesIO(b"some content"), "image.jpg")),
                   content_type="multipart/form-data",
                   )
        return File.query.first()

    def _add_bounding_box(self, file_id, label=2):
        """Attach a zero-sized bounding-box result (x/y/w/h format) to the file.

        NOTE(review): the default ``label=2`` assumes a Label with id 2 exists
        after the project's label provider has run — confirm against fixtures.
        """
        # Leading slash keeps this consistent with every other absolute URL
        # used in these tests instead of relying on path normalisation.
        self._post(f"/data/{file_id}/results", json=dict(
            type="bounding-box",
            data=dict(x=0, y=0, w=0, h=0),
            label=label,
        ))

    def test_project_creation(self):
        """Creating a project stores it with a model/label provider and, after
        background work finishes, at least one label."""
        self.assertEqual(0, Project.query.count())
        self.assertEqual(0, Label.query.count())

        self._create_project()

        self.assertEqual(1, Project.query.count())

        project = Project.query.first()

        self.assertIsNotNone(project)
        self.assertIsNotNone(project.model)
        self.assertIsNotNone(project.label_provider)

        # Labels are created asynchronously after the request returns.
        self.wait_for_coroutines()
        self.assertNotEqual(0, Label.query.count())

    def test_adding_file_with_result(self):
        """A file can be uploaded to a project and a result attached to it."""
        project = self._create_project()
        self.assertEqual(1, Project.query.count())

        self.wait_for_coroutines()

        self.assertEqual(0, File.query.count())

        file = self._upload_file(project.id)

        self.assertEqual(1, File.query.count())
        self.assertEqual(0, Result.query.count())

        self._add_bounding_box(file.id)

        self.assertEqual(1, Result.query.count())

    def test_cascade_after_project_removal(self):
        """Removing a project also removes its labels, files and results."""
        self.assertEqual(0, File.query.count())
        self.assertEqual(0, Result.query.count())
        self.assertEqual(0, Label.query.count())
        self.assertEqual(0, Project.query.count())

        # Capture ids eagerly rather than holding on to ORM instances across
        # requests and coroutine processing.
        project_id = self._create_project().id

        self.wait_for_coroutines()

        file_id = self._upload_file(project_id).id

        self.wait_for_coroutines()

        self._add_bounding_box(file_id)

        self.assertNotEqual(0, File.query.count())
        self.assertNotEqual(0, Result.query.count())
        self.assertNotEqual(0, Label.query.count())
        self.assertNotEqual(0, Project.query.count())

        self.wait_for_coroutines()
        # Presumably gives remaining background work time to settle before
        # removal — kept as in the original; confirm whether still needed.
        eventlet.sleep(3)

        self._post(f"/projects/{project_id}/remove",
                   json=dict(remove=True),
                   )

        self.assertEqual(0, Project.query.count())
        self.assertEqual(0, Label.query.count())
        self.assertEqual(0, File.query.count())
        self.assertEqual(0, Result.query.count())

    def test_result_download(self):
        """The project results endpoint returns each file with its results,
        including bounding-box coordinates and label id/name."""
        project_id = self._create_project().id

        self.wait_for_coroutines()

        file_id = self._upload_file(project_id).id

        self.wait_for_coroutines()

        self._add_bounding_box(file_id)

        self.wait_for_coroutines()
        # Presumably lets asynchronous work finish before reading results —
        # kept as in the original; confirm whether still needed.
        eventlet.sleep(3)

        response = self._get(f"/projects/{project_id}/results")

        self.assertTrue(response.is_json)

        file = File.query.first()
        result = Result.query.first()

        self.assertEqual(1, len(response.json))
        returned_file = response.json[0]

        self.assertEqual(returned_file["filename"], file.filename)

        self.assertEqual(1, len(returned_file["results"]))

        returned_result = returned_file["results"][0]

        xywh_should = [result.data.get(attr) for attr in "xywh"]
        xywh_is = [returned_result.get(attr) for attr in "xywh"]

        self.assertListEqual(xywh_is, xywh_should)

        self.assertEqual(returned_result["label"]["id"], result.label.id)
        self.assertEqual(returned_result["label"]["name"], result.label.name)