6
0
Răsfoiți Sursa

added a test for result retrieval. Fixed some about for that

Dimitri Korsch 3 ani în urmă
părinte
comite
0d4d420aa0

+ 1 - 1
migrations/versions/ece44a2b50ef_.py

@@ -101,7 +101,7 @@ def upgrade():
     sa.Column('origin', sa.String(), nullable=False),
     sa.Column('type', sa.String(), nullable=False),
     sa.Column('label_id', sa.Integer(), nullable=True),
-    sa.Column('data', sa.String(), nullable=True),
+    sa.Column('data_encoded', sa.String(), nullable=True),
     sa.ForeignKeyConstraint(['file_id'], ['file.id'], ondelete='CASCADE'),
     sa.ForeignKeyConstraint(['label_id'], ['label.id'], ondelete='SET NULL'),
     sa.PrimaryKeyConstraint('id')

+ 2 - 1
pycs/database/Database.py

@@ -153,7 +153,8 @@ class Database:
         """
         # prepare some values
 
-        return Project.new(commit=commit,
+        return Project.new(
+            commit=commit,
             name=name,
             description=description,
             model=model,

+ 9 - 5
pycs/database/File.py

@@ -59,6 +59,10 @@ class File(NamedBaseModel):
         "collection_id",
     )
 
+    @property
+    def filename(self):
+        return f"{self.name}{self.extension}"
+
     def set_collection(self, id: T.Optional[int]):
         """
         set this file's collection
@@ -137,12 +141,12 @@ class File(NamedBaseModel):
 
 
     def create_result(self, origin, result_type, label, data: T.Optional[dict] = None):
-        data = data if data is None else json.dumps(data)
-
-        result = Result.new(file=self,
+        result = Result.new(commit=False,
+                            file_id=self.id,
                             origin=origin,
-                            type=result_type,
-                            data=data)
+                            type=result_type)
+
+        result.data = data
 
         if label is not None:
             assert isinstance(label, (int, Label)), f"Wrong label type: {type(label)}"

+ 6 - 0
pycs/database/Label.py

@@ -48,6 +48,12 @@ class Label(NamedBaseModel):
         remote_side=[id],
     )
 
+    results = db.relationship("Result",
+        backref="label",
+        passive_deletes=True,
+        lazy="dynamic",
+    )
+
     serialize_only = (
         "id",
         "name",

+ 17 - 16
pycs/database/Result.py

@@ -21,7 +21,7 @@ class Result(BaseModel):
         db.ForeignKey("label.id", ondelete="SET NULL"),
         nullable=True)
 
-    data = db.Column(db.String)
+    data_encoded = db.Column(db.String)
 
     serialize_only = (
         "id",
@@ -29,14 +29,27 @@ class Result(BaseModel):
         "origin",
         "type",
         "label_id",
-        "data",
+        "data_encoded",
     )
 
+    @property
+    def data(self):
+        return None if self.data_encoded is None else json.loads(self.data_encoded)
+
+    @data.setter
+    def data(self, value):
+        if isinstance(value, str) or value is None:
+            self.data_encoded = value
+
+        elif isinstance(value, (dict, list)):
+            self.data_encoded = json.dumps(value)
+
+        else:
+            raise ValueError(f"Not supported type: {type(value)}")
 
     def serialize(self):
         result = super().serialize()
-        if result["data"] is not None:
-            result["data"] = json.loads(result["data"])
+        result["data"] = self.data
         return result
 
     def set_origin(self, origin: str):
@@ -59,15 +72,3 @@ class Result(BaseModel):
         """
         self.label_id = label
         self.commit()
-
-    def set_data(self, data: T.Optional[dict]):
-        """
-        set this results data object
-
-        :param data: data object
-        :return:
-        """
-        data = data if data is None else json.dumps(data)
-
-        self.data = data
-        self.commit()

+ 2 - 2
pycs/database/base.py

@@ -44,12 +44,12 @@ class BaseModel(db.Model, ModelSerializer):
             self.commit()
 
     @classmethod
-    def new(cls, commit=False, **kwargs):
+    def new(cls, commit=True, **kwargs):
         obj = cls(**kwargs)
         db.session.add(obj)
 
         if commit:
-            db.session.commit()
+            obj.commit()
         return obj
 
     @classmethod

+ 1 - 1
pycs/frontend/endpoints/results/EditResultData.py

@@ -31,7 +31,7 @@ class EditResultData(View):
 
         # start transaction and set label
         with self.db:
-            result.set_data(data['data'])
+            result.data = data['data']
             result.set_origin('user')
 
         self.nm.edit_result(result.id)

+ 1 - 1
pycs/frontend/endpoints/results/GetProjectResults.py

@@ -14,7 +14,7 @@ class GetProjectResults(View):
 
     def dispatch_request(self, project_id: int):
         # get project from database
-        project = Project.object.get(project_id)
+        project = Project.query.get(project_id)
         if project is None:
             return abort(404)
 

+ 1 - 1
pycs/interfaces/MediaFile.py

@@ -123,6 +123,6 @@ class MediaFile:
             'fps': self.fps,
             'path': self.path,
             'filename': self.__file.name + self.__file.extension,
-            'results': list(map(lambda r: r.serialize(), self.results)),
+            'results': list(map(lambda r: r.serialize(), self.results())),
             'predictions': list(map(lambda r: r.serialize(), self.predictions())),
         }

+ 69 - 5
test/test_client.py

@@ -11,16 +11,29 @@ from pycs.database.Project import Project
 
 class ClientTests(BaseTestCase):
 
+    def _do_request(self, request_func, *args,  status_code=200, **kwargs):
+        response = request_func(*args, follow_redirects=True, **kwargs)
+        self.assertEqual(response.status_code, status_code, response.get_data().decode())
+        return response
+
+
     def _post(self, url, status_code=200, content_type=None, json=None, data=None):
-        response = self.client.post(url,
+        return self._do_request(self.client.post,
+            url,
+            status_code=status_code,
             json=json,
             data=data,
-            follow_redirects=True,
             content_type=content_type,
         )
 
-        self.assertEqual(response.status_code, 200, response.get_data().decode())
-        return response
+    def _get(self, url, status_code=200, content_type=None, json=None, data=None):
+        return self._do_request(self.client.get,
+            url,
+            status_code=status_code,
+            json=json,
+            data=data,
+            content_type=content_type,
+        )
 
     def test_project_creation(self):
 
@@ -114,7 +127,7 @@ class ClientTests(BaseTestCase):
         self._post(f"data/{file_id}/results",
             json=dict(
                 type="bounding-box",
-                data=dict(x0=0, x1=0, y0=0, y1=0),
+                data=dict(x=0, y=0, w=0, h=0),
                 label=2,
             )
         )
@@ -137,7 +150,58 @@ class ClientTests(BaseTestCase):
         self.assertEqual(0, Result.query.count())
 
 
+    def test_result_download(self):
+
+        self._post("/projects",
+            json=dict(
+                name="some name",
+                description="some description",
+                model=1,
+                label=2,
+                external=None,
+            )
+        )
+        project = Project.query.first()
+        project_id = project.id
+
+        self.wait_for_coroutines()
+        self._post(f"/projects/{project_id}/data",
+            data=dict(file=(io.BytesIO(b"some content"), "image.jpg")),
+            content_type="multipart/form-data",
+        )
+        file = File.query.first()
+        file_id = file.id
+
+        self.wait_for_coroutines()
+        self._post(f"data/{file_id}/results",
+            json=dict(
+                type="bounding-box",
+                data=dict(x=0, y=0, w=0, h=0),
+                label=2,
+            )
+        )
+
+        self.wait_for_coroutines()
+        eventlet.sleep(3)
+
+        response = self._get(f"/projects/{project_id}/results")
+
+        self.assertTrue(response.is_json)
+
+        file = File.query.first()
+        result = Result.query.first()
+
+        self.assertEqual(1, len(response.json))
+        returned_file = response.json[0]
+        self.assertEqual(returned_file["filename"], file.filename)
 
+        self.assertEqual(1, len(returned_file["results"]))
+        returned_result = returned_file["results"][0]
 
+        xywh_should = [result.data.get(attr) for attr in "xywh"]
+        xywh_is = [returned_result.get(attr) for attr in "xywh"]
+        self.assertListEqual(xywh_is, xywh_should)
 
+        self.assertEqual(returned_result["label"]["id"], result.label.id)
+        self.assertEqual(returned_result["label"]["name"], result.label.name)
 

+ 2 - 0
test/test_database.py

@@ -20,6 +20,7 @@ class DatabaseTests(BaseTestCase):
             for i, supports in enumerate([["labeled-image", "fit"], ["labeled-bounding-boxes"], ["labeled-bounding-boxes"]], 1):
 
                 model = Model.new(
+                    commit=False,
                     name=f"Model {i}",
                     description=f"Description for Model {i}",
                     root_folder=f"modeldir{i}",
@@ -30,6 +31,7 @@ class DatabaseTests(BaseTestCase):
                     continue
 
                 provider = LabelProvider.new(
+                    commit=False,
                     name=f"Label Provider {i}",
                     description=f"Description for Label Provider {i}",
                     root_folder=f"labeldir{i}",