6
0
فهرست منبع

added tests for the creation of results

Dimitri Korsch 3 سال پیش
والد
کامیت
751d311d3b
3فایلهای تغییر یافته به همراه81 افزوده شده و 69 حذف شده
  1. 3 3
      pycs/frontend/endpoints/results/CreateResult.py
  2. 0 60
      pycs/frontend/endpoints/results/ResultAsCrop.py
  3. 78 6
      tests/client/result_tests.py

+ 3 - 3
pycs/frontend/endpoints/results/CreateResult.py

@@ -32,7 +32,7 @@ class CreateResult(View):
             abort(400, "result type must be either 'labeled-image' or 'bounding-box'")
             abort(400, "result type must be either 'labeled-image' or 'bounding-box'")
 
 
         label = None
         label = None
-        data = {}
+        data = None
 
 
         if result_type == 'labeled-image':
         if result_type == 'labeled-image':
             label = request_data.get('label')
             label = request_data.get('label')
@@ -49,10 +49,10 @@ class CreateResult(View):
 
 
     def dispatch_request(self, file_id: int):
     def dispatch_request(self, file_id: int):
 
 
+        file = File.get_or_404(file_id)
+
         result_type, label, data = self.extract_request_data()
         result_type, label, data = self.extract_request_data()
 
 
-        # find file
-        file = File.get_or_404(file_id)
 
 
         # start transaction
         # start transaction
         with db.session.begin_nested():
         with db.session.begin_nested():

+ 0 - 60
pycs/frontend/endpoints/results/ResultAsCrop.py

@@ -1,60 +0,0 @@
-import os
-
-from flask import abort
-from flask import send_from_directory
-from flask.views import View
-
-from pycs.database.Result import Result
-from pycs.util import file_ops
-
-
-class ResultAsCrop(View):
-    """
-    return the image crop defined by the result.
-    """
-    # pylint: disable=arguments-differ
-    methods = ['GET']
-
-
-    def dispatch_request(self, result_id: int, max_width: int = 2**24, max_height: int = 2**24):
-
-        # find result
-        result = Result.get_or_404(result_id)
-
-        if result.type != "bounding-box":
-            msg = f"The type of the queried result was not \"bounding-box\"! It was {result.type}"
-            abort(400, msg)
-
-        file = result.file
-
-        if file.type != "image":
-            abort(400, "Currently only supporting images!")
-
-
-        data = result.data
-
-        if data is None:
-            abort(400, "The data of the result was None!")
-
-        xywh = [data.get(attr, -1) for attr in "xywh"]
-        if -1 in xywh:
-            abort(400, f"The data of the result is not correct: {data}!")
-
-
-        crop_path, crop_fname = file_ops.crop_file(file, file.project.root_folder, *xywh)
-
-        parts = os.path.splitext(crop_fname)
-
-        crop_new_fname = f"{parts[0]}_{max_width}_{max_height}.{parts[1]}"
-
-        resized = file_ops.resize_image(
-            os.path.join(crop_path, crop_fname),
-            os.path.join(crop_path, crop_new_fname),
-            max_width,
-            max_height
-        )
-
-        if resized:
-            crop_fname = crop_new_fname
-
-        return send_from_directory(crop_path, crop_fname)

+ 78 - 6
tests/client/result_tests.py

@@ -1,18 +1,90 @@
+import uuid
+
 from flask import url_for
 from flask import url_for
 
 
-from pycs.database.Model import Model
-from pycs.database.Project import Project
+from pycs.database.Label import Label
+from pycs.database.Result import Result
 
 
-from tests.client.label_tests import _BaseLabelTests
+from tests.client.file_tests import _BaseFileTests
 
 
 
 
-class _BaseResultTests(_BaseLabelTests):
+class _BaseResultTests(_BaseFileTests):
 
 
     def setupModels(self):
     def setupModels(self):
         super().setupModels()
         super().setupModels()
 
 
+        file_uuid = str(uuid.uuid1())
+        self.file, is_new = self.project.add_file(
+            uuid=file_uuid,
+            file_type="image",
+            name=f"name",
+            filename=f"image",
+            extension=".jpg",
+            size=32*1024,
+        )
+        assert is_new, "the created file should be new!"
 
 
+class ResultCreationTests(_BaseResultTests):
+
+    def test_missing_file(self):
+        url = url_for("create_result", file_id=4242)
+        self.post(url, status_code=404)
+
+    def test_missing_flags(self):
+        url = url_for("create_result", file_id=self.file.id)
+
+
+        bad_data = [
+            None, # no request data at all
+            dict(), # type missing
+            dict(type="something"), # should be "labeled-image" or "bounding-box"
+            dict(type="labeled_image"), # should be with "-"
+            dict(type="bounding_box"), # should be with "-"
+            dict(type="labeled-image"), # label is missing
+            dict(type="bounding-box"), # data is missing
+        ]
+
+        for data in bad_data:
+            self.assertEqual(0, Result.query.count())
+            self.post(url, status_code=400, json=data)
+            self.assertEqual(0, Result.query.count())
+
+    def test_file_label(self):
+        url = url_for("create_result", file_id=self.file.id)
+
+        label, is_new = self.project.create_label(name="label", reference="some_label")
+        self.assertTrue(is_new)
+
+        self.assertEqual(0, Result.query.count())
+        self.post(url, json=dict(type="labeled-image", label=label.id))
+        self.assertEqual(1, Result.query.count())
+
+        result = Result.query.one_or_none()
+
+        self.assertIsNotNone(result)
+        self.assertEqual("user", result.origin)
+        self.assertEqual(self.file.id, result.file_id)
+        self.assertEqual(label.id, result.label_id)
+        self.assertEqual(label.name, result.label.name)
+        self.assertIsNone(result.data_encoded)
+        self.assertIsNone(result.data)
+
+    def test_bounding_box(self):
+        url = url_for("create_result", file_id=self.file.id)
+
+        self.assertEqual(0, Result.query.count())
+        box = dict(x=0, y=0.5, w=1/3, h=1/4)
+        self.post(url, json=dict(type="bounding-box", data=box))
+        self.assertEqual(1, Result.query.count())
+
+        result = Result.query.one_or_none()
+
+        self.assertIsNotNone(result)
+        self.assertEqual("user", result.origin)
+        self.assertEqual(self.file.id, result.file_id)
+        self.assertIsNotNone(result.data_encoded)
+        self.assertDictEqual(box, result.data)
+
+        self.assertIsNone(result.label_id)
 
 
 
 
-class ResultCreationTests(_BaseResultTests):
-    pass