123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284 |
- import uuid
- from flask import url_for
- from pycs.database.Label import Label
- from pycs.database.Result import Result
- from tests.client.file_tests import _BaseFileTests
class _BaseResultTests(_BaseFileTests):
    """Shared fixture for the result endpoint tests.

    On top of the models prepared by ``_BaseFileTests.setupModels``, every
    test gets one freshly added image file (``self.file``) that results can
    be attached to.
    """

    def setupModels(self):
        """Create the base models, then add a single image file to the project."""
        super().setupModels()

        file_uuid = str(uuid.uuid1())
        self.file, is_new = self.project.add_file(
            uuid=file_uuid,
            file_type="image",
            name="name",
            filename="image",
            extension=".jpg",
            size=32 * 1024,
        )
        # Fail the fixture loudly if add_file handed back an existing file.
        self.assertTrue(is_new, "the created file should be new!")
class ResultCreationTests(_BaseResultTests):
    """Tests for the ``create_result`` endpoint."""

    def test_missing_file(self):
        """Creating a result for an unknown file id answers with 404."""
        self.post(url_for("create_result", file_id=4242), status_code=404)

    def test_missing_flags(self):
        """Malformed creation requests are rejected and create no result."""
        invalid_payloads = (
            None,                       # no request data at all
            {},                         # type missing
            {"type": "something"},      # should be "labeled-image" or "bounding-box"
            {"type": "labeled_image"},  # should be with "-"
            {"type": "bounding_box"},   # should be with "-"
            {"type": "labeled-image"},  # label is missing
            {"type": "bounding-box"},   # data is missing
        )
        endpoint = url_for("create_result", file_id=self.file.id)

        for payload in invalid_payloads:
            self.assertEqual(0, Result.query.count())
            self.post(endpoint, status_code=400, json=payload)
            self.assertEqual(0, Result.query.count())

    def test_file_label(self):
        """A "labeled-image" result keeps its label and carries no data."""
        endpoint = url_for("create_result", file_id=self.file.id)
        label, created = self.project.create_label(name="label", reference="some_label")
        self.assertTrue(created)

        self.assertEqual(0, Result.query.count())
        self.post(endpoint, json={"type": "labeled-image", "label": label.id})
        self.assertEqual(1, Result.query.count())

        stored = Result.query.one_or_none()
        self.assertIsNotNone(stored)
        self.assertEqual("user", stored.origin)
        self.assertEqual(self.file.id, stored.file_id)
        self.assertEqual(label.id, stored.label_id)
        self.assertEqual(label.name, stored.label.name)
        self.assertIsNone(stored.data_encoded)
        self.assertIsNone(stored.data)

    def test_bounding_box(self):
        """A "bounding-box" result keeps its box data and has no label."""
        endpoint = url_for("create_result", file_id=self.file.id)
        self.assertEqual(0, Result.query.count())

        box = {"x": 0, "y": 0.5, "w": 1 / 3, "h": 1 / 4}
        self.post(endpoint, json={"type": "bounding-box", "data": box})
        self.assertEqual(1, Result.query.count())

        stored = Result.query.one_or_none()
        self.assertIsNotNone(stored)
        self.assertEqual("user", stored.origin)
        self.assertEqual(self.file.id, stored.file_id)
        self.assertIsNotNone(stored.data_encoded)
        self.assertDictEqual(box, stored.data)
        self.assertIsNone(stored.label_id)
class ResultGettingTests(_BaseResultTests):
    """Tests for the ``get_results`` endpoint."""

    def test_missing_file(self):
        """Requesting the results of an unknown file id answers with 404."""
        url = url_for("get_results", file_id=4242)
        self.get(url, status_code=404)

    def test_getting_of_results(self):
        """Only the requested file's results are returned, in serialized form."""
        n = 5
        box = dict(x=0, y=0, w=0.9, h=1.0)
        self.assertEqual(0, Result.query.count())

        # Results of the file under test, keyed by id for the lookup below.
        results = {}
        for _ in range(n):
            res = self.file.create_result("user", "bounding-box", data=box)
            results[res.id] = res
        self.assertEqual(n, Result.query.count())

        # A second file with its own results; none of them may show up in
        # the response for self.file.
        another_file, is_new = self.project.add_file(
            uuid=str(uuid.uuid1()),
            file_type="image",
            name="name2",
            filename="image2",
            extension=".jpg",
            size=32 * 1024,
        )
        self.assertTrue(is_new)
        for _ in range(n):
            another_file.create_result("user", "bounding-box", data=box)
        self.assertEqual(2 * n, Result.query.count())

        url = url_for("get_results", file_id=self.file.id)
        response = self.get(url)
        self.assertTrue(response.is_json)

        content = response.json
        self.assertEqual(n, len(content))
        # Exactly the ids of self.file's results -- nothing from another_file.
        self.assertSetEqual(set(results), {entry["id"] for entry in content})
        for entry in content:
            self.assertDictEqual(results[entry["id"]].serialize(), entry)
class ResultEditingTests(_BaseResultTests):
    """Tests for the endpoints that edit, relabel and confirm results."""

    def test_edit_data(self):
        """Editing the data replaces it and switches the origin to "user"."""
        self.post(url_for("edit_result_data", result_id=4242), status_code=404)

        box0 = dict(x=0, y=0, w=0.9, h=1.0)
        result = self.file.create_result("pipeline", "bounding-box", data=box0)
        self.assertDictEqual(box0, result.data)
        self.assertEqual("pipeline", result.origin)

        url = url_for("edit_result_data", result_id=result.id)
        for data in [None, dict()]:
            self.post(url, status_code=400, json=data)
            # Rejected requests must leave the result untouched.
            self.assertDictEqual(box0, result.data)
            self.assertEqual("pipeline", result.origin)

        # box1 must differ from box0; otherwise the assertion below would
        # pass even if the edit were never applied.
        box1 = dict(x=0.25, y=0.5, w=0.5, h=0.25)
        self.assertNotEqual(box0, box1)
        self.post(url, json=dict(data=box1))
        self.assertDictEqual(box1, result.data)
        self.assertEqual("user", result.origin)

    def test_edit_label(self):
        """Relabeling works for both result types and sets origin to "user".

        Unsetting the label of a "labeled-image" result is rejected.
        """
        self.post(url_for("edit_result_label", result_id=4242), status_code=404)

        label1, is_new = self.project.create_label(name="label1", reference="some_label1")
        self.assertTrue(is_new)
        label2, is_new = self.project.create_label(name="label2", reference="some_label2")
        self.assertTrue(is_new)

        for result_type in ["labeled-image", "bounding-box"]:
            result = self.file.create_result("pipeline", result_type, label=label1)
            self.assertEqual(label1.id, result.label_id)
            self.assertEqual(label1.name, result.label.name)
            self.assertEqual("pipeline", result.origin)

            url = url_for("edit_result_label", result_id=result.id)
            for data in [None, dict()]:
                self.post(url, status_code=400, json=data)

            if result_type == "labeled-image":
                self.post(url, status_code=400, json=dict(label=None))

            # None of the rejected requests may have modified the result.
            self.assertEqual(label1.id, result.label_id)
            self.assertEqual("pipeline", result.origin)

            self.post(url, json=dict(label=label2.id))
            self.assertEqual(label2.id, result.label_id)
            self.assertEqual(label2.name, result.label.name)
            self.assertEqual("user", result.origin)

    def test_unset_label_of_bounding_box(self):
        """A bounding box's label can be removed by sending ``label=None``."""
        label, is_new = self.project.create_label(name="label", reference="some_label")
        self.assertTrue(is_new)

        result = self.file.create_result("pipeline", "bounding-box", label=label)
        self.assertEqual(label.id, result.label_id)
        self.assertEqual(label.name, result.label.name)
        self.assertEqual("pipeline", result.origin)

        url = url_for("edit_result_label", result_id=result.id)
        self.post(url, json=dict(label=None))
        self.assertIsNone(result.label_id)
        self.assertEqual("user", result.origin)

    def test_confirm_result(self):
        """Only an explicit ``confirm=True`` switches the origin to "user"."""
        self.post(url_for("confirm_result", result_id=4242), status_code=404)

        label, is_new = self.project.create_label(name="label", reference="some_label1")
        self.assertTrue(is_new)

        for result_type in ["labeled-image", "bounding-box"]:
            result = self.file.create_result("pipeline", result_type, label=label)
            url = url_for("confirm_result", result_id=result.id)

            for data in [None, dict(), dict(confirm=False)]:
                self.post(url, status_code=400, json=data)
                self.assertEqual("pipeline", result.origin)

            self.post(url, json=dict(confirm=True))
            self.assertEqual("user", result.origin)
class ResultRemovalTests(_BaseResultTests):
    """Tests for removing single results and resetting all results of a file."""

    def test_remove_result(self):
        """A result is removed only on an explicit ``remove=True``."""
        self.post(url_for("remove_result", result_id=4242), status_code=404)

        label, is_new = self.project.create_label(name="label", reference="some_label1")
        self.assertTrue(is_new)
        self.assertEqual(0, Result.query.count())

        for result_type in ["labeled-image", "bounding-box"]:
            result = self.file.create_result("pipeline", result_type, label=label)
            self.assertEqual(1, Result.query.count())
            self.assertEqual(1, self.file.results.count())

            url = url_for("remove_result", result_id=result.id)
            for data in [None, dict(), dict(remove=False)]:
                self.post(url, status_code=400, json=data)
                # Rejected requests must not delete anything.
                self.assertEqual(1, Result.query.count())
                self.assertEqual(1, self.file.results.count())

            self.post(url, json=dict(remove=True))
            self.assertEqual(0, Result.query.count())
            self.assertEqual(0, self.file.results.count())

        self.assertEqual(0, Result.query.count())

    def test_reset_file_results(self):
        """All results of a file are dropped only on an explicit ``reset=True``."""
        self.post(url_for("reset_results", file_id=4242), status_code=404)

        label, is_new = self.project.create_label(name="label", reference="some_label1")
        self.assertTrue(is_new)
        self.assertEqual(0, Result.query.count())

        n = 5
        # The target file never changes, so the URL is built once.
        url = url_for("reset_results", file_id=self.file.id)
        for result_type in ["labeled-image", "bounding-box"]:
            for _ in range(n):
                self.file.create_result("pipeline", result_type, label=label)

            for data in [None, dict(), dict(reset=False)]:
                self.post(url, status_code=400, json=data)
                # Rejected requests must not delete anything.
                self.assertEqual(n, Result.query.count())
                self.assertEqual(n, self.file.results.count())

            self.post(url, json=dict(reset=True))
            self.assertEqual(0, Result.query.count())
            self.assertEqual(0, self.file.results.count())
|