6
0
ソースを参照

added result removal tests

Dimitri Korsch 3 年 前
コミット
f41a934677

+ 2 - 2
pycs/frontend/endpoints/results/RemoveResult.py

@@ -19,14 +19,14 @@ class RemoveResult(View):
         self.nm = nm
         self.nm = nm
 
 
     def dispatch_request(self, result_id: int):
     def dispatch_request(self, result_id: int):
+        result = Result.get_or_404(result_id)
+
         # extract request data
         # extract request data
         data = request.get_json(force=True)
         data = request.get_json(force=True)
 
 
         if not data.get('remove', False):
         if not data.get('remove', False):
             abort(400, "remove flag is missing")
             abort(400, "remove flag is missing")
 
 
-        # find result
-        result = Result.get_or_404(result_id)
 
 
         dump = result.delete()
         dump = result.delete()
 
 

+ 2 - 2
pycs/frontend/endpoints/results/ResetResults.py

@@ -19,14 +19,14 @@ class ResetResults(View):
         self.nm = nm
         self.nm = nm
 
 
     def dispatch_request(self, file_id: int):
     def dispatch_request(self, file_id: int):
+        file = File.get_or_404(file_id)
+
         # extract request data
         # extract request data
         data = request.get_json(force=True)
         data = request.get_json(force=True)
 
 
         if not data.get('reset', False):
         if not data.get('reset', False):
             abort(400)
             abort(400)
 
 
-        # find file
-        file = File.get_or_404(file_id)
 
 
         removed = []
         removed = []
 
 

+ 52 - 0
tests/client/result_tests.py

@@ -230,3 +230,55 @@ class ResultEditingTests(_BaseResultTests):
             self.post(url, json=dict(confirm=True))
             self.post(url, json=dict(confirm=True))
             self.assertEqual("user", result.origin)
             self.assertEqual("user", result.origin)
 
 
+class ResultRemovalTests(_BaseResultTests):
+
+    def test_remove_result(self):
+        self.post(url_for("remove_result", result_id=4242), status_code=404)
+
+        label, is_new = self.project.create_label(name="label", reference="some_label1")
+
+        self.assertEqual(0, Result.query.count())
+        for result_type in ["labeled-image", "bounding-box"]:
+            result = self.file.create_result("pipeline", result_type, label=label)
+            self.assertEqual(1, Result.query.count())
+            self.assertEqual(1, self.file.results.count())
+
+            url = url_for("remove_result", result_id=result.id)
+
+            for data in [None, dict(), dict(remove=False)]:
+                self.post(url, status_code=400, json=data)
+
+            self.assertEqual(1, Result.query.count())
+            self.assertEqual(1, self.file.results.count())
+
+            self.post(url, json=dict(remove=True))
+
+            self.assertEqual(0, Result.query.count())
+            self.assertEqual(0, self.file.results.count())
+
+        self.assertEqual(0, Result.query.count())
+
+    def test_reset_file_results(self):
+
+        self.post(url_for("reset_results", file_id=4242), status_code=404)
+
+        label, is_new = self.project.create_label(name="label", reference="some_label1")
+
+        self.assertEqual(0, Result.query.count())
+
+        n = 5
+        for result_type in ["labeled-image", "bounding-box"]:
+            for i in range(n):
+                self.file.create_result("pipeline", result_type, label=label)
+
+            url = url_for("reset_results", file_id=self.file.id)
+
+            for data in [None, dict(), dict(reset=False)]:
+                self.post(url, status_code=400, json=data)
+
+            self.assertEqual(n, Result.query.count())
+            self.assertEqual(n, self.file.results.count())
+            self.post(url, json=dict(reset=True))
+
+            self.assertEqual(0, Result.query.count())
+            self.assertEqual(0, self.file.results.count())