소스 검색

Some fixes according to the newest changes

Dimitri Korsch 2 년 전
부모
커밋
a58a813c20

+ 2 - 1
pycs/__init__.py

@@ -39,7 +39,8 @@ app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False
 # Protect via http basic authentication
 app.config['FLASK_HTPASSWD_PATH'] = '.htpasswd'
 if not os.path.isfile(app.config['FLASK_HTPASSWD_PATH']):
-    raise FileNotFoundError("You need to specify a .htpasswd-file. The following file could not be located: " + app.config['FLASK_HTPASSWD_PATH'] + "!")
+    raise FileNotFoundError("You need to specify a .htpasswd-file."
+        f"The following file could not be located: {app.config['FLASK_HTPASSWD_PATH']}!")
 app.config['FLASK_SECRET'] = 'Hey Hey Kids, secure me!'
 htpasswd = HtPasswdAuth(app)
 

+ 1 - 1
pycs/database/File.py

@@ -147,7 +147,7 @@ class File(NamedBaseModel):
 
         annot_query = File.results.any()
 
-        if with_annotations == False:
+        if with_annotations is False:
             annot_query = ~annot_query
 
         return result.filter(annot_query)

+ 35 - 7
pycs/database/Project.py

@@ -4,6 +4,7 @@ import typing as T
 
 from datetime import datetime
 from eventlet import tpool
+from sqlalchemy.sql import case
 
 from pycs import app
 from pycs import db
@@ -166,7 +167,7 @@ class Project(NamedBaseModel):
         return label, is_new
 
     @commit_on_return
-    def bulk_create_labels(self, labels: T.List[T.Dict], clean_old_labels: bool = True):
+    def bulk_create_labels(self, labels: T.List[T.Dict]):
         """
             Inserts all labels at once.
 
@@ -177,15 +178,42 @@ class Project(NamedBaseModel):
         if len(labels) == 0:
             return labels
 
-        if clean_old_labels:
-            self.labels.delete()
-
         for label in labels:
             label["project_id"] = self.id
 
         self.__check_labels(labels)
-        app.logger.info(f"Inserting {len(labels):,d} labels")
-        db.engine.execute(Label.__table__.insert(), labels)
+
+
+        # first update existing labels
+        fields_to_update = (
+            ("name", Label.name),
+            ("hierarchy_level", Label.hierarchy_level),
+        )
+
+        updates = {
+            field: case(
+                {lab["reference"]: lab[key] for lab in labels},
+                value=Label.reference)
+
+            for key, field in fields_to_update
+        }
+
+        existing_labs = self.labels.filter(
+            Label.reference.in_([lab["reference"] for lab in labels])
+        )
+        app.logger.info(f"Updating {existing_labs.count():,d} labels")
+        existing_labs.update(updates, synchronize_session=False)
+
+        # then add new labels
+        references = {lab.reference for lab in self.labels.all()}
+        new_labels = [lab for lab in labels
+            if lab["reference"] not in references]
+
+        if len(new_labels) > 0:
+            app.logger.info(f"Inserting {len(new_labels):,d} new labels")
+            db.engine.execute(Label.__table__.insert(), new_labels)
+
+        # finally set parents correctly
         self.__set_parents(labels)
 
         return labels
@@ -304,7 +332,7 @@ class Project(NamedBaseModel):
         if with_annotations is not None:
             annot_query = File.results.any()
 
-            if with_annotations == False:
+            if with_annotations is False:
                 annot_query = ~annot_query
 
             filters = filters + (annot_query,)

+ 7 - 0
pycs/database/__init__.py

@@ -0,0 +1,7 @@
+from pycs.database.Collection import Collection
+from pycs.database.File import File
+from pycs.database.Label import Label
+from pycs.database.LabelProvider import LabelProvider
+from pycs.database.Model import Model
+from pycs.database.Project import Project
+from pycs.database.Result import Result

+ 18 - 29
pycs/frontend/endpoints/pipelines/EstimateBoundingBox.py

@@ -1,14 +1,14 @@
-import cv2
+import typing as T
 import uuid
+
+import cv2
 import numpy as np
-import typing as T
 
 from flask import abort
 from flask import make_response
 from flask import request
 from flask.views import View
 
-from pycs import db
 from pycs.database.File import File
 from pycs.database.Result import Result
 from pycs.frontend.notifications.NotificationManager import NotificationManager
@@ -34,7 +34,7 @@ class EstimateBoundingBox(View):
         if 'x' not in request_data or 'y' not in request_data:
             abort(400, "coordinates for the estimation are missing")
 
-        x,y = map(request_data.get, "xy")
+        x, y = map(request_data.get, "xy")
 
         # get project
         project = file.project
@@ -56,13 +56,15 @@ class EstimateBoundingBox(View):
 
 
 def estimate(file_id: int, x: float, y: float) -> Result:
+    """ estimation function """
+
     file = File.query.get(file_id)
 
-    im = cv2.imread(file.absolute_path, cv2.IMREAD_GRAYSCALE)
+    image = cv2.imread(file.absolute_path, cv2.IMREAD_GRAYSCALE)
 
-    h, w = im.shape
+    h, w = image.shape
     pos = int(x * w), int(y * h)
-    x0, y0, x1, y1 = detect(im, pos,
+    x0, y0, x1, y1 = detect(image, pos,
                             window_size=1000,
                             pixel_delta=50,
                             enlarge=1e-2,
@@ -77,18 +79,19 @@ def estimate(file_id: int, x: float, y: float) -> Result:
 
     return file.create_result('pipeline', 'bounding-box', label=None, data=data)
 
-def detect(im: np.ndarray,
+def detect(image: np.ndarray,
            pos: T.Tuple[int, int],
            window_size: int = 1000,
            pixel_delta: int = 0,
            enlarge: float = -1) -> T.Tuple[int, int, int, int]:
-    # im = blur(im, 3)
+    """ detection function """
+    # image = blur(image, 3)
     x, y = pos
-    pixel = im[y, x]
+    pixel = image[y, x]
 
     min_pix, max_pix = pixel - pixel_delta, pixel + pixel_delta
 
-    mask = np.logical_and(min_pix < im, im < max_pix).astype(np.float32)
+    mask = np.logical_and(min_pix < image, image < max_pix).astype(np.float32)
     # mask = open_close(mask)
     # mask = blur(mask)
 
@@ -98,18 +101,19 @@ def detect(im: np.ndarray,
 
     sum_x, sum_y = window.sum(axis=0), window.sum(axis=1)
 
-    enlarge = int(enlarge * max(im.shape))
+    enlarge = int(enlarge * max(image.shape))
     (x0, x1), (y0, y1) = get_borders(sum_x, enlarge), get_borders(sum_y, enlarge)
 
     x0 = max(x + x0 - pad, 0)
     y0 = max(y + y0 - pad, 0)
 
-    x1 = min(x + x1 - pad, im.shape[1])
-    y1 = min(y + y1 - pad, im.shape[0])
+    x1 = min(x + x1 - pad, image.shape[1])
+    y1 = min(y + y1 - pad, image.shape[0])
 
     return x0, y0, x1, y1
 
 def get_borders(arr, enlarge: int, eps=5e-1):
+    """ returns borders based on coordinate extrema """
     mid = len(arr) // 2
 
     arr0, arr1 = arr[:mid], arr[mid:]
@@ -130,18 +134,3 @@ def get_borders(arr, enlarge: int, eps=5e-1):
         upper = min(upper + enlarge, len(arr)-1)
 
     return int(lower), int(upper)
-
-
-"""
-def blur(im, sigma=5):
-    from skimage import filters
-    return filters.gaussian(im, sigma=sigma, preserve_range=True)
-
-def open_close(im, kernel_size=3):
-
-    kernel = np.ones((kernel_size, kernel_size), dtype=np.uint8)
-
-    im = cv2.morphologyEx(im, cv2.MORPH_OPEN, kernel)
-    im = cv2.morphologyEx(im, cv2.MORPH_CLOSE, kernel)
-    return im
-"""

+ 0 - 1
pycs/frontend/endpoints/pipelines/PredictModel.py

@@ -12,7 +12,6 @@ from pycs.database.Result import Result
 from pycs.frontend.notifications.NotificationList import NotificationList
 from pycs.frontend.notifications.NotificationManager import NotificationManager
 from pycs.interfaces.MediaFile import MediaFile
-from pycs.interfaces.MediaLabel import MediaLabel
 from pycs.interfaces.MediaBoundingBox import MediaBoundingBox
 from pycs.interfaces.MediaStorage import MediaStorage
 from pycs.jobs.JobGroupBusyException import JobGroupBusyException

+ 0 - 2
pycs/frontend/endpoints/results/CopyResults.py

@@ -1,12 +1,10 @@
 from flask import abort
-from flask import jsonify
 from flask import make_response
 from flask import request
 from flask.views import View
 
 from pycs import db
 from pycs.database.File import File
-from pycs.database.File import Result
 from pycs.frontend.notifications.NotificationManager import NotificationManager
 
 

+ 173 - 3
pycs/management/result.py

@@ -1,10 +1,11 @@
 import click
 import flask
+import simplejson as json
 
 from flask.cli import AppGroup
 
 from pycs import app
-from pycs.database.Project import Project
+from pycs import database as db
 
 result_cli = AppGroup("result", short_help="Result operations")
 
@@ -17,13 +18,13 @@ result_cli = AppGroup("result", short_help="Result operations")
 def export(project_id, output, indent):
     """ Export results for a specific project or for all projects """
     if project_id == "all":
-        projects = Project.query.all()
+        projects = db.Project.query.all()
         app.logger.info(f"Exporting results for all projects ({len(projects)})!")
         if output is None:
             output = "output.json"
 
     else:
-        project = Project.query.get(project_id)
+        project = db.Project.query.get(project_id)
         if project is None:
             app.logger.error(f"Could not find project with ID {project_id}!")
             return
@@ -58,3 +59,172 @@ def export(project_id, output, indent):
 
     with open(output, "w", encoding="utf-8") as out_f:
         flask.json.dump(results, out_f, app=app, indent=indent)
+
+
+
+
+@result_cli.command("restore")
+@click.argument("infile")
+@click.option("--dry-run", is_flag=True)
+def restore(infile, dry_run):
+    """ Restore results from a JSON export file into the database """
+
+    with open(infile, encoding="utf-8") as f:
+        results = json.load(f)
+
+    for project_results in results:
+        project = db.Project.get_or_404(project_results["project_id"])
+        for file_results in project_results["files"]:
+            file = db.File.get_or_404(file_results["id"])
+
+            assert file.path == file_results["path"]
+
+            # first check for new and changed results
+            for _result in file_results["results"]:
+
+                if not _is_data_valid(**_result):
+                    continue
+
+                result = get_result_or_none(file, **_result)
+
+                user1 = _result["origin_user"]
+                data1 = _result["data"]
+                ref1 = (_result["label"] or {}).get("reference")
+                # lab1 = (_result["label"] or {}).get("id")
+
+
+                if result is None:
+                    # we have a new result entry
+                    if not dry_run:
+                        file.create_result(
+                            result_type="bounding-box",
+                            origin="user",
+                            origin_user=user1,
+                            label=ref1,
+                            data=data1,
+                            commit=True
+                        )
+                    print(" | ".join([
+                        f"Project #{project.id:< 6d}"
+                        f"File #{file.id:< 6d} [{file.name:^30s}]",
+                        "[New Result]",
+                        f"User: {user1 or '':<10s}",
+                        f"Data: {data1}, Label-Ref: {ref1}",
+                        ])
+                    )
+
+                    continue
+
+                assert result.file_id == _result["file_id"]
+                user0 = result.origin_user
+                data0 = result.data
+                ref0 = getattr(result.label, "reference", None)
+                # lab0 = getattr(result.label, "id", None)
+
+                is_same_data = _check_data(data0, data1)
+
+                if is_same_data and (ref0 == ref1 or ref1 is None):
+                    # nothing to change
+                    continue
+
+                print(" | ".join([
+                    f"Project #{project.id:< 6d}"
+                    f"File #{file.id:< 6d} [{file.name:^30s}]",
+                    ]), end=" | "
+                )
+                if not is_same_data:
+                    # data was updated
+                    print(" | ".join([
+                        "[Data updated]",
+                        f"User: {user1 or '':<10s}",
+                        f"Data: {data0} -> {data1}"
+                        ]), end=" | "
+                    )
+                    assert user1 is not None
+                    if not dry_run:
+                        result.origin_user = user1
+                        result.data = data1
+
+                if ref0 != ref1:
+                    assert user1 is not None
+                    if not dry_run:
+                        result.origin_user = user1
+                    if ref1 is None:
+                        # label was deleted
+                        print("[Label Deleted]")
+                        if not dry_run:
+                            result.label_id = None
+                    else:
+                        # label was updated
+                        print(" | ".join([
+                            "[Label updated]",
+                            f"User: {user0 or '':<10s} -> {user1 or '':<10s}",
+                            f"{ref0 or 'UNK':<6s} -> {ref1 or 'UNK':<6s}"
+                            ])
+                        )
+                        label = project.label_by_reference(ref1)
+                        if not dry_run:
+                            result.label_id = label.id
+                else:
+                    print()
+
+                if not dry_run:
+                    result.commit()
+
+            # then check for deleted results
+            for result in file.results.all():
+                if result.origin != "user" or result.type != "bounding-box":
+                    continue
+
+                found = False
+                for _result in file_results["results"]:
+                    if not _is_data_valid(**_result):
+                        continue
+
+                    if _check_data(result.data, _result["data"]):
+                        found = True
+                        break
+
+                if not found:
+                    print(" | ".join([
+                        f"Project #{project.id:< 6d}"
+                        f"File #{file.id:< 6d} [{file.name:^30s}]",
+                        "[Result deleted]",
+                        f"{result.data}",
+                        f"{result.label}",
+                        ])
+                    )
+
+                    if not dry_run:
+                        result.delete()
+
+def _is_data_valid(*, data, type, origin, **kwargs):  # pylint: disable=redefined-builtin
+
+    # NOTE(review): computed before the type check below — a non-None data dict
+    # without "w"/"h" keys would raise KeyError; confirm that cannot happen here
+    wh = (None, None) if data is None else (data["w"], data["h"])
+
+    return (type != "labeled-image" and
+        origin == "user" and
+        0 not in wh)
+
+def _check_data(data0, data1):
+
+    if None in (data0, data1):
+        return data0 is None and data1 is None
+
+    for key in data0:
+        if data1.get(key) != data0.get(key):
+            return False
+
+    return True
+
+def get_result_or_none(file: db.File, id: int, data: dict, **kwargs):
+
+    result = db.Result.query.filter(
+        db.Result.id==id, db.Result.file_id==file.id).one_or_none()
+
+    if result is not None:
+        return result
+
+    for other_results in file.results.all():
+        if _check_data(data, other_results.data):
+            return other_results

+ 1 - 0
requirements.txt

@@ -9,6 +9,7 @@ flask-sqlalchemy
 sqlalchemy_serializer
 flask-migrate
 flask-htpasswd
+itsdangerous~=2.0.1
 python-socketio
 munch
 scikit-image

+ 26 - 3
tests/client/pipeline_tests.py

@@ -179,10 +179,33 @@ class LabelProviderPipelineTests:
     def test_label_loading_multiple(self):
 
         for i in range(3):
-            self.post(self.url, json=dict(execute=True))
-            self.wait_for_bg_jobs()
+            self.test_label_loading()
+
+    def test_multiple_loading_does_not_delete_existing_labels(self):
+        self.test_label_loading()
+
+        file = self.project.files.first()
+
+        def _check():
+            for res in file.results.all():
+                self.assertIsNotNone(res.label_id)
+
+        for label in self.project.labels:
+            file.create_result(
+                origin="user",
+                result_type="bounding-box",
+                label=label,
+                data=dict(x=0, y=0, w=0.2, h=0.3),
+            )
+
+        file.commit()
+
+        _check()
+
+        for _ in range(3):
+            self.test_label_loading()
+            _check()
 
-            self.assertEqual(self.n_labels, self.project.labels.count())
 
 class SimpleLabelProviderPipelineTests(LabelProviderPipelineTests, _BasePipelineTests):