Explorar o código

added tests for editing results. fixed some minor bugs

Dimitri Korsch %!s(int64=4) %!d(string=hai) anos
pai
achega
a31522ce27

+ 4 - 2
pycs/frontend/endpoints/results/ConfirmResult.py

@@ -17,14 +17,16 @@ class ConfirmResult(View):
         self.nm = nm
         self.nm = nm
 
 
     def dispatch_request(self, result_id: int):
     def dispatch_request(self, result_id: int):
+        # find result
+        result = Result.get_or_404(result_id)
+
         # extract request data
         # extract request data
         data = request.get_json(force=True)
         data = request.get_json(force=True)
 
 
+
         if not data.get('confirm', False):
         if not data.get('confirm', False):
             return abort(400, "confirm flag is missing")
             return abort(400, "confirm flag is missing")
 
 
-        # find result
-        result = Result.get_or_404(result_id)
 
 
         result.set_origin('user')
         result.set_origin('user')
 
 

+ 3 - 3
pycs/frontend/endpoints/results/EditResultData.py

@@ -19,6 +19,9 @@ class EditResultData(View):
         self.nm = nm
         self.nm = nm
 
 
     def dispatch_request(self, result_id: int):
     def dispatch_request(self, result_id: int):
+        # find result
+        result = Result.get_or_404(result_id)
+
         # extract request data
         # extract request data
         request_data = request.get_json(force=True)
         request_data = request.get_json(force=True)
         data = request_data.get('data')
         data = request_data.get('data')
@@ -26,9 +29,6 @@ class EditResultData(View):
         if data is None:
         if data is None:
             abort(400, "Could not find data argument!")
             abort(400, "Could not find data argument!")
 
 
-        # find result
-        result = Result.get_or_404(result_id)
-
         result.data = data
         result.data = data
         result.set_origin('user', commit=True)
         result.set_origin('user', commit=True)
 
 

+ 5 - 4
pycs/frontend/endpoints/results/EditResultLabel.py

@@ -19,21 +19,22 @@ class EditResultLabel(View):
         self.nm = nm
         self.nm = nm
 
 
     def dispatch_request(self, result_id: int):
     def dispatch_request(self, result_id: int):
+        # find result
+        result = Result.get_or_404(result_id)
+
         # extract request data
         # extract request data
         request_data = request.get_json(force=True)
         request_data = request.get_json(force=True)
 
 
         if 'label' not in request_data:
         if 'label' not in request_data:
             abort(400, "Could not find label argument!")
             abort(400, "Could not find label argument!")
 
 
-        # find result
-        result = Result.get_or_404(result_id)
         label = request_data.get('label')
         label = request_data.get('label')
 
 
         # abort if label is empty for labeled-images
         # abort if label is empty for labeled-images
         if result.type == 'labeled-image' and label is None:
         if result.type == 'labeled-image' and label is None:
-            return abort(400, "Label is required for 'labeled-images' results")
+            abort(400, "Label is required for 'labeled-images' results")
 
 
-        result.label = label
+        result.label_id = label
         result.set_origin('user', commit=True)
         result.set_origin('user', commit=True)
 
 
         self.nm.edit_result(result)
         self.nm.edit_result(result)

+ 91 - 0
tests/client/result_tests.py

@@ -139,3 +139,94 @@ class ResultGettingTests(_BaseResultTests):
 
 
             self.assertDictEqual(res.serialize(), entry)
             self.assertDictEqual(res.serialize(), entry)
 
 
+
+class ResultEditingTests(_BaseResultTests):
+
+    def test_edit_data(self):
+
+        self.post(url_for("edit_result_data", result_id=4242), status_code=404)
+
+        box0 = dict(x=0, y=0, w=0.9, h=1.0)
+        result = self.file.create_result("pipeline", "bounding-box", data=box0)
+
+        self.assertDictEqual(box0, result.data)
+        self.assertEqual("pipeline", result.origin)
+
+
+        url = url_for("edit_result_data", result_id=result.id)
+
+        for data in [None, dict()]:
+            self.post(url, status_code=400, json=data)
+
+        box1 = dict(x=0, y=0, w=0.9, h=1.0)
+        self.post(url, json=dict(data=box1))
+
+        self.assertDictEqual(box1, result.data)
+        self.assertEqual("user", result.origin)
+
+    def test_edit_label(self):
+        self.post(url_for("edit_result_label", result_id=4242), status_code=404)
+
+        label1, is_new = self.project.create_label(name="label1", reference="some_label1")
+        self.assertTrue(is_new)
+
+        label2, is_new = self.project.create_label(name="label2", reference="some_label2")
+        self.assertTrue(is_new)
+
+        for result_type in ["labeled-image", "bounding-box"]:
+
+            result = self.file.create_result("pipeline", result_type, label=label1)
+
+            self.assertEqual(label1.id, result.label_id)
+            self.assertEqual(label1.name, result.label.name)
+            self.assertEqual("pipeline", result.origin)
+
+            url = url_for("edit_result_label", result_id=result.id)
+
+            for data in [None, dict()]:
+                self.post(url, status_code=400, json=data)
+
+            if result_type == "labeled-image":
+                self.post(url, status_code=400, json=dict(label=None))
+
+            self.post(url, json=dict(label=label2.id))
+
+            self.assertEqual(label2.id, result.label_id)
+            self.assertEqual(label2.name, result.label.name)
+            self.assertEqual("user", result.origin)
+
+
+    def test_unset_label_of_bounding_box(self):
+        label, is_new = self.project.create_label(name="label", reference="some_label")
+        self.assertTrue(is_new)
+
+        result = self.file.create_result("pipeline", "bounding-box", label=label)
+
+        self.assertEqual(label.id, result.label_id)
+        self.assertEqual(label.name, result.label.name)
+        self.assertEqual("pipeline", result.origin)
+
+        url = url_for("edit_result_label", result_id=result.id)
+
+        self.post(url, json=dict(label=None))
+
+        self.assertEqual(None, result.label_id)
+        self.assertEqual("user", result.origin)
+
+    def test_confirm_result(self):
+        self.post(url_for("confirm_result", result_id=4242), status_code=404)
+
+        label, is_new = self.project.create_label(name="label", reference="some_label1")
+
+        for result_type in ["labeled-image", "bounding-box"]:
+            result = self.file.create_result("pipeline", result_type, label=label)
+            url = url_for("confirm_result", result_id=result.id)
+
+
+            for data in [None, dict(), dict(confirm=False)]:
+                self.post(url, status_code=400, json=data)
+
+            self.assertEqual("pipeline", result.origin)
+            self.post(url, json=dict(confirm=True))
+            self.assertEqual("user", result.origin)
+