|
@@ -0,0 +1,215 @@
|
|
|
+from flask import url_for
|
|
|
+
|
|
|
+
|
|
|
+from pycs import db
|
|
|
+from pycs.database.Label import Label
|
|
|
+from pycs.database.Model import Model
|
|
|
+from pycs.database.Project import Project
|
|
|
+
|
|
|
+from tests.base import BaseTestCase
|
|
|
+
|
|
|
+
|
|
|
+class _BaseLabelTests(BaseTestCase):
|
|
|
+
|
|
|
+ def setupModels(self):
|
|
|
+
|
|
|
+ model = Model.new(
|
|
|
+ commit=False,
|
|
|
+ name="TestModel",
|
|
|
+ description="Model for a test case",
|
|
|
+ root_folder="model_folder",
|
|
|
+ )
|
|
|
+ model.supports = ["labeled-image"]
|
|
|
+ model.flush()
|
|
|
+
|
|
|
+ self.project = Project.new(
|
|
|
+ name="test_project",
|
|
|
+ description="Project for a test case",
|
|
|
+ model=model,
|
|
|
+ root_folder="project_folder",
|
|
|
+ external_data=False,
|
|
|
+ data_folder="project_folder/data",
|
|
|
+ )
|
|
|
+
|
|
|
+class LabelCreationTests(_BaseLabelTests):
|
|
|
+
|
|
|
+ def setUp(self):
|
|
|
+ super().setUp()
|
|
|
+ self.url = url_for("create_label", project_id=self.project.id)
|
|
|
+
|
|
|
+ def test_create_labels(self):
|
|
|
+ self.assertEqual(0, self.project.labels.count())
|
|
|
+
|
|
|
+ for i in range(1, 11):
|
|
|
+ self.post(self.url, json=dict(name=f"Label{i}"))
|
|
|
+ self.assertEqual(self.project.labels.count(), i)
|
|
|
+ last_label = Label.query.order_by(Label.id.desc()).first()
|
|
|
+ self.assertEqual(last_label.name, f"Label{i}")
|
|
|
+
|
|
|
+ def test_create_label_with_same_name(self):
|
|
|
+ self.assertEqual(0, self.project.labels.count())
|
|
|
+
|
|
|
+ name = "SameName"
|
|
|
+ for i in range(1, 11):
|
|
|
+ self.post(self.url, json=dict(name=name))
|
|
|
+ self.assertEqual(self.project.labels.count(), i)
|
|
|
+ last_label = Label.query.order_by(Label.id.desc()).first()
|
|
|
+ self.assertEqual(last_label.name, name)
|
|
|
+
|
|
|
+ def test_create_labels_with_reference(self):
|
|
|
+ self.assertEqual(0, self.project.labels.count())
|
|
|
+
|
|
|
+ for i in range(1, 11):
|
|
|
+ self.post(self.url, json=dict(name=f"Label{i}", reference=f"label_ref{i}"))
|
|
|
+ self.assertEqual(self.project.labels.count(), i)
|
|
|
+ last_label = Label.query.order_by(Label.id.desc()).first()
|
|
|
+ self.assertEqual(last_label.name, f"Label{i}")
|
|
|
+ self.assertEqual(last_label.reference, f"label_ref{i}")
|
|
|
+
|
|
|
+
|
|
|
+ def test_create_labels_with_same_reference(self):
|
|
|
+ self.assertEqual(0, self.project.labels.count())
|
|
|
+
|
|
|
+ ref = "same_ref"
|
|
|
+ self.post(self.url, json=dict(name=f"Label", reference=ref))
|
|
|
+ last_label = Label.query.order_by(Label.id.desc()).first()
|
|
|
+ self.assertEqual(last_label.name, f"Label")
|
|
|
+ self.assertEqual(last_label.reference, ref)
|
|
|
+
|
|
|
+ for i in range(2, 11):
|
|
|
+ self.post(self.url, status_code=400, json=dict(name=f"Label{i}", reference=ref))
|
|
|
+ self.assertEqual(self.project.labels.count(), 1)
|
|
|
+ last_label = Label.query.order_by(Label.id.desc()).first()
|
|
|
+ self.assertEqual(last_label.name, f"Label")
|
|
|
+ self.assertEqual(last_label.reference, ref)
|
|
|
+
|
|
|
+
|
|
|
+class LabelRemovalTests(_BaseLabelTests):
|
|
|
+
|
|
|
+ def url(self, label_id):
|
|
|
+ return url_for("remove_label",
|
|
|
+ project_id=self.project.id, label_id=label_id)
|
|
|
+
|
|
|
+ def setupModels(self):
|
|
|
+ super().setupModels()
|
|
|
+ self.N = 10
|
|
|
+ self.labels = []
|
|
|
+ with db.session.begin_nested():
|
|
|
+ for i in range(1, self.N+1):
|
|
|
+ label, is_new = self.project.create_label(
|
|
|
+ commit=False,
|
|
|
+ name=f"Label{i}",
|
|
|
+ reference=f"Label{i}",
|
|
|
+ )
|
|
|
+ assert is_new, "labels should be new"
|
|
|
+ label.flush()
|
|
|
+ self.labels.append(label)
|
|
|
+
|
|
|
+ def test_remove_label_without_flag(self):
|
|
|
+ self.assertEqual(self.N, self.project.labels.count())
|
|
|
+
|
|
|
+ self.post(self.url(self.labels[0].id), status_code=400)
|
|
|
+ self.assertEqual(self.N, self.project.labels.count())
|
|
|
+
|
|
|
+ self.post(self.url(self.labels[0].id), status_code=400, json=dict(remove=False))
|
|
|
+ self.assertEqual(self.N, self.project.labels.count())
|
|
|
+
|
|
|
+
|
|
|
+ def test_remove_label(self):
|
|
|
+ self.assertEqual(self.N, self.project.labels.count())
|
|
|
+
|
|
|
+ self.post(self.url(self.labels[0].id), json=dict(remove=True))
|
|
|
+ self.assertEqual(self.N-1, self.project.labels.count())
|
|
|
+
|
|
|
+
|
|
|
+class LabelTreeRemovalTests(_BaseLabelTests):
|
|
|
+
|
|
|
+ def url(self, label_id):
|
|
|
+ return url_for("remove_label",
|
|
|
+ project_id=self.project.id, label_id=label_id)
|
|
|
+
|
|
|
+ def setupModels(self):
|
|
|
+ super().setupModels()
|
|
|
+ self.n_roots = 5
|
|
|
+ self.leafs_per_root = 3
|
|
|
+ self.n_leafs = self.n_roots * self.leafs_per_root
|
|
|
+ self.N = self.n_roots + self.n_leafs
|
|
|
+
|
|
|
+ with db.session.begin_nested():
|
|
|
+ for i in range(1, self.n_roots+1):
|
|
|
+ root_label, is_new = self.project.create_label(
|
|
|
+ commit=False,
|
|
|
+ name=f"Label{i}",
|
|
|
+ reference=f"Label{i}",
|
|
|
+ )
|
|
|
+ root_label.flush()
|
|
|
+
|
|
|
+ for j in range(1, self.leafs_per_root+1):
|
|
|
+ label, is_new = self.project.create_label(
|
|
|
+ commit=False,
|
|
|
+ name=f"Label{i}_{j}",
|
|
|
+ reference=f"Label{i}_{j}",
|
|
|
+ parent=root_label
|
|
|
+ )
|
|
|
+ label.flush()
|
|
|
+
|
|
|
+
|
|
|
+ @property
|
|
|
+ def root_labels(self):
|
|
|
+ return self.project.labels.filter(Label.parent_id == None)
|
|
|
+
|
|
|
+ @property
|
|
|
+ def leaf_labels(self):
|
|
|
+ return self.project.labels.filter(Label.parent_id != None)
|
|
|
+
|
|
|
+ def test_remove_root_label(self):
|
|
|
+ self.assertEqual(self.N, self.project.labels.count())
|
|
|
+ self.assertEqual(self.n_roots, self.root_labels.count())
|
|
|
+ self.assertEqual(self.n_leafs, self.leaf_labels.count())
|
|
|
+
|
|
|
+ label = self.root_labels.first()
|
|
|
+
|
|
|
+ self.post(self.url(label.id), json=dict(remove=True))
|
|
|
+
|
|
|
+ self.assertEqual(self.N-1, self.project.labels.count())
|
|
|
+ self.assertEqual(self.n_roots-1+self.leafs_per_root, self.root_labels.count())
|
|
|
+ self.assertEqual(self.n_leafs-self.leafs_per_root, self.leaf_labels.count())
|
|
|
+
|
|
|
+ def test_remove_leaf_label(self):
|
|
|
+ self.assertEqual(self.N, self.project.labels.count())
|
|
|
+ self.assertEqual(self.n_roots, self.root_labels.count())
|
|
|
+ self.assertEqual(self.n_leafs, self.leaf_labels.count())
|
|
|
+
|
|
|
+ label = self.leaf_labels.first()
|
|
|
+ self.post(self.url(label.id), json=dict(remove=True))
|
|
|
+
|
|
|
+ self.assertEqual(self.N-1, self.project.labels.count())
|
|
|
+ self.assertEqual(self.n_roots, self.root_labels.count())
|
|
|
+ self.assertEqual(self.n_leafs-1, self.leaf_labels.count())
|
|
|
+
|
|
|
+ def test_remove_intermediate_label(self):
|
|
|
+ self.assertEqual(self.N, self.project.labels.count())
|
|
|
+ self.assertEqual(self.n_roots, self.root_labels.count())
|
|
|
+ self.assertEqual(self.n_leafs, self.leaf_labels.count())
|
|
|
+
|
|
|
+ label = self.leaf_labels.first()
|
|
|
+ label_parent = label.parent
|
|
|
+
|
|
|
+ new_label, is_new = self.project.create_label(
|
|
|
+ commit=False,
|
|
|
+ name="New Label",
|
|
|
+ reference="new_ref",
|
|
|
+ parent=label,
|
|
|
+ )
|
|
|
+
|
|
|
+ self.assertTrue(is_new)
|
|
|
+ self.assertEqual(label.id, new_label.parent_id)
|
|
|
+ self.assertEqual(self.n_roots, self.root_labels.count())
|
|
|
+ self.assertEqual(self.n_leafs + 1, self.leaf_labels.count())
|
|
|
+
|
|
|
+ self.post(self.url(label.id), json=dict(remove=True))
|
|
|
+
|
|
|
+ self.assertEqual(label_parent.id, new_label.parent_id)
|
|
|
+ self.assertEqual(self.n_roots, self.root_labels.count())
|
|
|
+ self.assertEqual(self.n_leafs, self.leaf_labels.count())
|
|
|
+
|