6
0
Эх сурвалжийг харах

added label creation and removal tests

Dimitri Korsch 3 жил өмнө
parent
commit
14830c0dfe

+ 6 - 5
pycs/database/base.py

@@ -20,7 +20,10 @@ class BaseModel(db.Model, SerializerMixin):
 
 
     id = db.Column(db.Integer, primary_key=True)
     id = db.Column(db.Integer, primary_key=True)
 
 
-    serialize_only = ("id",)
+    serialize_only = ("id", "identifier")
+
+    def identifier(self):
+        return self.id
 
 
     def __repr__(self):
     def __repr__(self):
         attrs = self.serialize()
         attrs = self.serialize()
@@ -29,10 +32,8 @@ class BaseModel(db.Model, SerializerMixin):
 
 
 
 
     def serialize(self) -> dict:
     def serialize(self) -> dict:
-        """ default model serialize method. adds identifier as alias for id """
-        res = self.to_dict()
-        res["identifier"] = self.id
-        return res
+        """ default model serialization method"""
+        return self.to_dict()
 
 
 
 
     @commit_on_return
     @commit_on_return

+ 6 - 1
pycs/frontend/endpoints/labels/CreateLabel.py

@@ -29,6 +29,7 @@ class CreateLabel(View):
             abort(400)
             abort(400)
 
 
         name = data['name']
         name = data['name']
+        reference = data.get('reference')
         parent = data.get('parent')
         parent = data.get('parent')
 
 
         # find project
         # find project
@@ -36,7 +37,11 @@ class CreateLabel(View):
 
 
         # insert label
         # insert label
         with db.session.begin_nested():
         with db.session.begin_nested():
-            label, is_new = project.create_label(name, parent=parent, commit=False)
+            label, is_new = project.create_label(
+                name=name,
+                reference=reference,
+                parent=parent,
+                commit=False)
             if not is_new:
             if not is_new:
                 abort(400, f"Label already exists: {label}")
                 abort(400, f"Label already exists: {label}")
 
 

+ 7 - 7
pycs/frontend/endpoints/labels/RemoveLabel.py

@@ -37,16 +37,16 @@ class RemoveLabel(View):
         # start transaction
         # start transaction
         with db.session.begin_nested():
         with db.session.begin_nested():
 
 
-            # update children's parent entry
-            Label.query.filter(Label.parent == self).update({Label.parent: label.parent},
-                synchronize_session=False)
-
-            # remove label
+            children = label.children
+            # remove label. because of cascade, parent_id is set ot NULL
             label_dump = label.delete(commit=False)
             label_dump = label.delete(commit=False)
+            label.flush()
 
 
-            # notify about changes
-            for child in children:
+            # notify about changes and reset the parent
+            for child in label.children:
+                child.parent_id = label.parent_id
                 self.nm.edit_label(child)
                 self.nm.edit_label(child)
+
             self.nm.remove_label(label_dump)
             self.nm.remove_label(label_dump)
 
 
         # return success response
         # return success response

+ 2 - 0
tests/__init__.py

@@ -0,0 +1,2 @@
+from tests.test_database import *
+from tests.client.label_tests import *

+ 48 - 0
tests/base.py

@@ -2,6 +2,7 @@
 import os
 import os
 import shutil
 import shutil
 import unittest
 import unittest
+import typing as T
 
 
 from pycs import app
 from pycs import app
 from pycs import db
 from pycs import db
@@ -13,6 +14,7 @@ from pycs.database.LabelProvider import LabelProvider
 server = None
 server = None
 
 
 class BaseTestCase(unittest.TestCase):
 class BaseTestCase(unittest.TestCase):
+
     def setUp(self, discovery: bool = False):
     def setUp(self, discovery: bool = False):
         global server
         global server
         app.config["TESTING"] = True
         app.config["TESTING"] = True
@@ -24,6 +26,8 @@ class BaseTestCase(unittest.TestCase):
         db.create_all()
         db.create_all()
 
 
         self.client = app.test_client()
         self.client = app.test_client()
+        self.context = app.test_request_context()
+        self.context.push()
 
 
         # init the server once
         # init the server once
         if server is None:
         if server is None:
@@ -34,10 +38,54 @@ class BaseTestCase(unittest.TestCase):
             Model.discover("models/")
             Model.discover("models/")
             LabelProvider.discover("labels/")
             LabelProvider.discover("labels/")
 
 
+        self.setupModels()
+
+    def setupModels(self):
+        pass
+
     def tearDown(self):
     def tearDown(self):
+        self.context.pop()
 
 
         if os.path.exists(self.projects_dir):
         if os.path.exists(self.projects_dir):
             shutil.rmtree(self.projects_dir)
             shutil.rmtree(self.projects_dir)
 
 
         db.drop_all()
         db.drop_all()
 
 
+
+    def _do_request(self, request_func, *args,  status_code=200, **kwargs):
+        response = request_func(*args, follow_redirects=True, **kwargs)
+        self.assertEqual(response.status_code, status_code, response.get_data().decode())
+        return response
+
+    def post(self,
+             url: str,
+             *,
+             status_code: int = 200,
+             content_type: T.Optional[str] = None,
+             data: T.Optional[dict] = None,
+             json: T.Optional[dict] = None):
+
+        return self._do_request(
+            self.client.post,
+            url,
+            status_code=status_code,
+            json=json,
+            data=data,
+            content_type=content_type,
+        )
+
+    def get(self,
+             url: str,
+             *,
+             status_code: int = 200,
+             data: T.Optional[dict] = None,
+             json: T.Optional[dict] = None):
+
+        return self._do_request(
+            self.client.get,
+            url,
+            status_code=status_code,
+            json=json,
+            data=data,
+            content_type=content_type,
+        )

+ 1 - 0
tests/client/__init__.py

@@ -0,0 +1 @@
+from tests.client.label_tests import *

+ 215 - 0
tests/client/label_tests.py

@@ -0,0 +1,215 @@
+from flask import url_for
+
+
+from pycs import db
+from pycs.database.Label import Label
+from pycs.database.Model import Model
+from pycs.database.Project import Project
+
+from tests.base import BaseTestCase
+
+
+class _BaseLabelTests(BaseTestCase):
+
+    def setupModels(self):
+
+        model = Model.new(
+            commit=False,
+            name="TestModel",
+            description="Model for a test case",
+            root_folder="model_folder",
+        )
+        model.supports = ["labeled-image"]
+        model.flush()
+
+        self.project = Project.new(
+            name="test_project",
+            description="Project for a test case",
+            model=model,
+            root_folder="project_folder",
+            external_data=False,
+            data_folder="project_folder/data",
+        )
+
+class LabelCreationTests(_BaseLabelTests):
+
+    def setUp(self):
+        super().setUp()
+        self.url = url_for("create_label", project_id=self.project.id)
+
+    def test_create_labels(self):
+        self.assertEqual(0, self.project.labels.count())
+
+        for i in range(1, 11):
+            self.post(self.url, json=dict(name=f"Label{i}"))
+            self.assertEqual(self.project.labels.count(), i)
+            last_label = Label.query.order_by(Label.id.desc()).first()
+            self.assertEqual(last_label.name, f"Label{i}")
+
+    def test_create_label_with_same_name(self):
+        self.assertEqual(0, self.project.labels.count())
+
+        name = "SameName"
+        for i in range(1, 11):
+            self.post(self.url, json=dict(name=name))
+            self.assertEqual(self.project.labels.count(), i)
+            last_label = Label.query.order_by(Label.id.desc()).first()
+            self.assertEqual(last_label.name, name)
+
+    def test_create_labels_with_reference(self):
+        self.assertEqual(0, self.project.labels.count())
+
+        for i in range(1, 11):
+            self.post(self.url, json=dict(name=f"Label{i}", reference=f"label_ref{i}"))
+            self.assertEqual(self.project.labels.count(), i)
+            last_label = Label.query.order_by(Label.id.desc()).first()
+            self.assertEqual(last_label.name, f"Label{i}")
+            self.assertEqual(last_label.reference, f"label_ref{i}")
+
+
+    def test_create_labels_with_same_reference(self):
+        self.assertEqual(0, self.project.labels.count())
+
+        ref = "same_ref"
+        self.post(self.url, json=dict(name=f"Label", reference=ref))
+        last_label = Label.query.order_by(Label.id.desc()).first()
+        self.assertEqual(last_label.name, f"Label")
+        self.assertEqual(last_label.reference, ref)
+
+        for i in range(2, 11):
+            self.post(self.url, status_code=400, json=dict(name=f"Label{i}", reference=ref))
+            self.assertEqual(self.project.labels.count(), 1)
+            last_label = Label.query.order_by(Label.id.desc()).first()
+            self.assertEqual(last_label.name, f"Label")
+            self.assertEqual(last_label.reference, ref)
+
+
+class LabelRemovalTests(_BaseLabelTests):
+
+    def url(self, label_id):
+        return url_for("remove_label",
+            project_id=self.project.id, label_id=label_id)
+
+    def setupModels(self):
+        super().setupModels()
+        self.N = 10
+        self.labels = []
+        with db.session.begin_nested():
+            for i in range(1, self.N+1):
+                label, is_new = self.project.create_label(
+                    commit=False,
+                    name=f"Label{i}",
+                    reference=f"Label{i}",
+                )
+                assert is_new, "labels should be new"
+                label.flush()
+                self.labels.append(label)
+
+    def test_remove_label_without_flag(self):
+        self.assertEqual(self.N, self.project.labels.count())
+
+        self.post(self.url(self.labels[0].id), status_code=400)
+        self.assertEqual(self.N, self.project.labels.count())
+
+        self.post(self.url(self.labels[0].id), status_code=400, json=dict(remove=False))
+        self.assertEqual(self.N, self.project.labels.count())
+
+
+    def test_remove_label(self):
+        self.assertEqual(self.N, self.project.labels.count())
+
+        self.post(self.url(self.labels[0].id), json=dict(remove=True))
+        self.assertEqual(self.N-1, self.project.labels.count())
+
+
+class LabelTreeRemovalTests(_BaseLabelTests):
+
+    def url(self, label_id):
+        return url_for("remove_label",
+            project_id=self.project.id, label_id=label_id)
+
+    def setupModels(self):
+        super().setupModels()
+        self.n_roots = 5
+        self.leafs_per_root = 3
+        self.n_leafs = self.n_roots * self.leafs_per_root
+        self.N = self.n_roots + self.n_leafs
+
+        with db.session.begin_nested():
+            for i in range(1, self.n_roots+1):
+                root_label, is_new = self.project.create_label(
+                    commit=False,
+                    name=f"Label{i}",
+                    reference=f"Label{i}",
+                )
+                root_label.flush()
+
+                for j in range(1, self.leafs_per_root+1):
+                    label, is_new = self.project.create_label(
+                        commit=False,
+                        name=f"Label{i}_{j}",
+                        reference=f"Label{i}_{j}",
+                        parent=root_label
+                    )
+                    label.flush()
+
+
+    @property
+    def root_labels(self):
+        return self.project.labels.filter(Label.parent_id == None)
+
+    @property
+    def leaf_labels(self):
+        return self.project.labels.filter(Label.parent_id != None)
+
+    def test_remove_root_label(self):
+        self.assertEqual(self.N, self.project.labels.count())
+        self.assertEqual(self.n_roots, self.root_labels.count())
+        self.assertEqual(self.n_leafs, self.leaf_labels.count())
+
+        label = self.root_labels.first()
+
+        self.post(self.url(label.id), json=dict(remove=True))
+
+        self.assertEqual(self.N-1, self.project.labels.count())
+        self.assertEqual(self.n_roots-1+self.leafs_per_root, self.root_labels.count())
+        self.assertEqual(self.n_leafs-self.leafs_per_root, self.leaf_labels.count())
+
+    def test_remove_leaf_label(self):
+        self.assertEqual(self.N, self.project.labels.count())
+        self.assertEqual(self.n_roots, self.root_labels.count())
+        self.assertEqual(self.n_leafs, self.leaf_labels.count())
+
+        label = self.leaf_labels.first()
+        self.post(self.url(label.id), json=dict(remove=True))
+
+        self.assertEqual(self.N-1, self.project.labels.count())
+        self.assertEqual(self.n_roots, self.root_labels.count())
+        self.assertEqual(self.n_leafs-1, self.leaf_labels.count())
+
+    def test_remove_intermediate_label(self):
+        self.assertEqual(self.N, self.project.labels.count())
+        self.assertEqual(self.n_roots, self.root_labels.count())
+        self.assertEqual(self.n_leafs, self.leaf_labels.count())
+
+        label = self.leaf_labels.first()
+        label_parent = label.parent
+
+        new_label, is_new = self.project.create_label(
+            commit=False,
+            name="New Label",
+            reference="new_ref",
+            parent=label,
+        )
+
+        self.assertTrue(is_new)
+        self.assertEqual(label.id, new_label.parent_id)
+        self.assertEqual(self.n_roots, self.root_labels.count())
+        self.assertEqual(self.n_leafs + 1, self.leaf_labels.count())
+
+        self.post(self.url(label.id), json=dict(remove=True))
+
+        self.assertEqual(label_parent.id, new_label.parent_id)
+        self.assertEqual(self.n_roots, self.root_labels.count())
+        self.assertEqual(self.n_leafs, self.leaf_labels.count())
+

+ 1 - 2
tests/test_database.py

@@ -12,9 +12,8 @@ from tests.base import BaseTestCase
 
 
 
 
 class TestDatabase(BaseTestCase):
 class TestDatabase(BaseTestCase):
-    def setUp(self) -> None:
-        super().setUp()
 
 
+    def setupModels(self):
         with db.session.begin_nested():
         with db.session.begin_nested():
 
 
             for i, supports in enumerate([["labeled-image", "fit"], ["labeled-bounding-boxes"], ["labeled-bounding-boxes"]], 1):
             for i, supports in enumerate([["labeled-image", "fit"], ["labeled-bounding-boxes"], ["labeled-bounding-boxes"]], 1):

+ 8 - 4
webui/src/components/other/LabelTreeView.vue

@@ -5,8 +5,8 @@
        draggable="true" @dragstart="dragstart" @dragend="dragend"
        draggable="true" @dragstart="dragstart" @dragend="dragend"
        @dragover="dragover" @dragleave="dragleave" @drop="drop">
        @dragover="dragover" @dragleave="dragleave" @drop="drop">
     <editable-headline :value="label.name"
     <editable-headline :value="label.name"
-                       @change="editLabel(label.identifier, $event)"
-                       @remove="removeLabel(label.identifier)">
+                       @change="editLabel($event)"
+                       @remove="removeLabel()">
       <div class="hierarchy"
       <div class="hierarchy"
            :class="{margined: label.hierarchy_level}"
            :class="{margined: label.hierarchy_level}"
            @click="collapse = !collapse">
            @click="collapse = !collapse">
@@ -43,6 +43,7 @@ export default {
   components: {EditableHeadline},
   components: {EditableHeadline},
   props: ['label', 'indent', 'targetable'],
   props: ['label', 'indent', 'targetable'],
   data: function () {
   data: function () {
+    console.log(this.label)
     return {
     return {
       untouched: true,
       untouched: true,
       target: false,
       target: false,
@@ -63,12 +64,15 @@ export default {
     }
     }
   },
   },
   methods: {
   methods: {
-    editLabel: function (id, value) {
+    editLabel: function (value) {
       // TODO then / error
       // TODO then / error
+      let id = this.label.identifier;
       this.$root.socket.post(`/projects/${this.$root.project.identifier}/labels/${id}/name`, {name: value});
       this.$root.socket.post(`/projects/${this.$root.project.identifier}/labels/${id}/name`, {name: value});
     },
     },
-    removeLabel: function (id) {
+    removeLabel: function () {
       // TODO then / error
       // TODO then / error
+      console.log(this.label)
+      let id = this.label.identifier;
       this.$root.socket.post(`/projects/${this.$root.project.identifier}/labels/${id}/remove`, {remove: true});
       this.$root.socket.post(`/projects/${this.$root.project.identifier}/labels/${id}/remove`, {remove: true});
     },
     },
     dragstart: function (e) {
     dragstart: function (e) {