6
0
Эх сурвалжийг харах

Some fixes according to the newest changes

Dimitri Korsch 2 жил өмнө
parent
commit
a58a813c20

+ 2 - 1
pycs/__init__.py

@@ -39,7 +39,8 @@ app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False
 # Protect via http basic authentication
 # Protect via http basic authentication
 app.config['FLASK_HTPASSWD_PATH'] = '.htpasswd'
 app.config['FLASK_HTPASSWD_PATH'] = '.htpasswd'
 if not os.path.isfile(app.config['FLASK_HTPASSWD_PATH']):
 if not os.path.isfile(app.config['FLASK_HTPASSWD_PATH']):
-    raise FileNotFoundError("You need to specify a .htpasswd-file. The following file could not be located: " + app.config['FLASK_HTPASSWD_PATH'] + "!")
+    raise FileNotFoundError("You need to specify a .htpasswd-file."
+        f"The following file could not be located: {app.config['FLASK_HTPASSWD_PATH']}!")
 app.config['FLASK_SECRET'] = 'Hey Hey Kids, secure me!'
 app.config['FLASK_SECRET'] = 'Hey Hey Kids, secure me!'
 htpasswd = HtPasswdAuth(app)
 htpasswd = HtPasswdAuth(app)
 
 

+ 1 - 1
pycs/database/File.py

@@ -147,7 +147,7 @@ class File(NamedBaseModel):
 
 
         annot_query = File.results.any()
         annot_query = File.results.any()
 
 
-        if with_annotations == False:
+        if with_annotations is False:
             annot_query = ~annot_query
             annot_query = ~annot_query
 
 
         return result.filter(annot_query)
         return result.filter(annot_query)

+ 35 - 7
pycs/database/Project.py

@@ -4,6 +4,7 @@ import typing as T
 
 
 from datetime import datetime
 from datetime import datetime
 from eventlet import tpool
 from eventlet import tpool
+from sqlalchemy.sql import case
 
 
 from pycs import app
 from pycs import app
 from pycs import db
 from pycs import db
@@ -166,7 +167,7 @@ class Project(NamedBaseModel):
         return label, is_new
         return label, is_new
 
 
     @commit_on_return
     @commit_on_return
-    def bulk_create_labels(self, labels: T.List[T.Dict], clean_old_labels: bool = True):
+    def bulk_create_labels(self, labels: T.List[T.Dict]):
         """
         """
            Inserts all labels at once.
            Inserts all labels at once.
 
 
@@ -177,15 +178,42 @@ class Project(NamedBaseModel):
         if len(labels) == 0:
         if len(labels) == 0:
             return labels
             return labels
 
 
-        if clean_old_labels:
-            self.labels.delete()
-
         for label in labels:
         for label in labels:
             label["project_id"] = self.id
             label["project_id"] = self.id
 
 
         self.__check_labels(labels)
         self.__check_labels(labels)
-        app.logger.info(f"Inserting {len(labels):,d} labels")
-        db.engine.execute(Label.__table__.insert(), labels)
+
+
+        # first update existing labels
+        fields_to_update = (
+            ("name", Label.name),
+            ("hierarchy_level", Label.hierarchy_level),
+        )
+
+        updates = {
+            field: case(
+                {lab["reference"]: lab[key] for lab in labels},
+                value=Label.reference)
+
+            for key, field in fields_to_update
+        }
+
+        existing_labs = self.labels.filter(
+            Label.reference.in_([lab["reference"] for lab in labels])
+        )
+        app.logger.info(f"Updating {existing_labs.count():,d} labels")
+        existing_labs.update(updates, synchronize_session=False)
+
+        # then add new labels
+        references = {lab.reference for lab in self.labels.all()}
+        new_labels = [lab for lab in labels
+            if lab["reference"] not in references]
+
+        if len(new_labels) > 0:
+            app.logger.info(f"Inserting {len(new_labels):,d} new labels")
+            db.engine.execute(Label.__table__.insert(), new_labels)
+
+        # finally set parents correctly
         self.__set_parents(labels)
         self.__set_parents(labels)
 
 
         return labels
         return labels
@@ -304,7 +332,7 @@ class Project(NamedBaseModel):
         if with_annotations is not None:
         if with_annotations is not None:
             annot_query = File.results.any()
             annot_query = File.results.any()
 
 
-            if with_annotations == False:
+            if with_annotations is False:
                 annot_query = ~annot_query
                 annot_query = ~annot_query
 
 
             filters = filters + (annot_query,)
             filters = filters + (annot_query,)

+ 7 - 0
pycs/database/__init__.py

@@ -0,0 +1,7 @@
+from pycs.database.Collection import Collection
+from pycs.database.File import File
+from pycs.database.Label import Label
+from pycs.database.LabelProvider import LabelProvider
+from pycs.database.Model import Model
+from pycs.database.Project import Project
+from pycs.database.Result import Result

+ 18 - 29
pycs/frontend/endpoints/pipelines/EstimateBoundingBox.py

@@ -1,14 +1,14 @@
-import cv2
+import typing as T
 import uuid
 import uuid
+
+import cv2
 import numpy as np
 import numpy as np
-import typing as T
 
 
 from flask import abort
 from flask import abort
 from flask import make_response
 from flask import make_response
 from flask import request
 from flask import request
 from flask.views import View
 from flask.views import View
 
 
-from pycs import db
 from pycs.database.File import File
 from pycs.database.File import File
 from pycs.database.Result import Result
 from pycs.database.Result import Result
 from pycs.frontend.notifications.NotificationManager import NotificationManager
 from pycs.frontend.notifications.NotificationManager import NotificationManager
@@ -34,7 +34,7 @@ class EstimateBoundingBox(View):
         if 'x' not in request_data or 'y' not in request_data:
         if 'x' not in request_data or 'y' not in request_data:
             abort(400, "coordinates for the estimation are missing")
             abort(400, "coordinates for the estimation are missing")
 
 
-        x,y = map(request_data.get, "xy")
+        x, y = map(request_data.get, "xy")
 
 
         # get project
         # get project
         project = file.project
         project = file.project
@@ -56,13 +56,15 @@ class EstimateBoundingBox(View):
 
 
 
 
 def estimate(file_id: int, x: float, y: float) -> Result:
 def estimate(file_id: int, x: float, y: float) -> Result:
+    """ estimation function """
+
     file = File.query.get(file_id)
     file = File.query.get(file_id)
 
 
-    im = cv2.imread(file.absolute_path, cv2.IMREAD_GRAYSCALE)
+    image = cv2.imread(file.absolute_path, cv2.IMREAD_GRAYSCALE)
 
 
-    h, w = im.shape
+    h, w = image.shape
     pos = int(x * w), int(y * h)
     pos = int(x * w), int(y * h)
-    x0, y0, x1, y1 = detect(im, pos,
+    x0, y0, x1, y1 = detect(image, pos,
                             window_size=1000,
                             window_size=1000,
                             pixel_delta=50,
                             pixel_delta=50,
                             enlarge=1e-2,
                             enlarge=1e-2,
@@ -77,18 +79,19 @@ def estimate(file_id: int, x: float, y: float) -> Result:
 
 
     return file.create_result('pipeline', 'bounding-box', label=None, data=data)
     return file.create_result('pipeline', 'bounding-box', label=None, data=data)
 
 
-def detect(im: np.ndarray,
+def detect(image: np.ndarray,
            pos: T.Tuple[int, int],
            pos: T.Tuple[int, int],
            window_size: int = 1000,
            window_size: int = 1000,
            pixel_delta: int = 0,
            pixel_delta: int = 0,
            enlarge: float = -1) -> T.Tuple[int, int, int, int]:
            enlarge: float = -1) -> T.Tuple[int, int, int, int]:
-    # im = blur(im, 3)
+    """ detection function """
+    # image = blur(image, 3)
     x, y = pos
     x, y = pos
-    pixel = im[y, x]
+    pixel = image[y, x]
 
 
     min_pix, max_pix = pixel - pixel_delta, pixel + pixel_delta
     min_pix, max_pix = pixel - pixel_delta, pixel + pixel_delta
 
 
-    mask = np.logical_and(min_pix < im, im < max_pix).astype(np.float32)
+    mask = np.logical_and(min_pix < image, image < max_pix).astype(np.float32)
     # mask = open_close(mask)
     # mask = open_close(mask)
     # mask = blur(mask)
     # mask = blur(mask)
 
 
@@ -98,18 +101,19 @@ def detect(im: np.ndarray,
 
 
     sum_x, sum_y = window.sum(axis=0), window.sum(axis=1)
     sum_x, sum_y = window.sum(axis=0), window.sum(axis=1)
 
 
-    enlarge = int(enlarge * max(im.shape))
+    enlarge = int(enlarge * max(image.shape))
     (x0, x1), (y0, y1) = get_borders(sum_x, enlarge), get_borders(sum_y, enlarge)
     (x0, x1), (y0, y1) = get_borders(sum_x, enlarge), get_borders(sum_y, enlarge)
 
 
     x0 = max(x + x0 - pad, 0)
     x0 = max(x + x0 - pad, 0)
     y0 = max(y + y0 - pad, 0)
     y0 = max(y + y0 - pad, 0)
 
 
-    x1 = min(x + x1 - pad, im.shape[1])
-    y1 = min(y + y1 - pad, im.shape[0])
+    x1 = min(x + x1 - pad, image.shape[1])
+    y1 = min(y + y1 - pad, image.shape[0])
 
 
     return x0, y0, x1, y1
     return x0, y0, x1, y1
 
 
 def get_borders(arr, enlarge: int, eps=5e-1):
 def get_borders(arr, enlarge: int, eps=5e-1):
+    """ returns borders based on coordinate extrema """
     mid = len(arr) // 2
     mid = len(arr) // 2
 
 
     arr0, arr1 = arr[:mid], arr[mid:]
     arr0, arr1 = arr[:mid], arr[mid:]
@@ -130,18 +134,3 @@ def get_borders(arr, enlarge: int, eps=5e-1):
         upper = min(upper + enlarge, len(arr)-1)
         upper = min(upper + enlarge, len(arr)-1)
 
 
     return int(lower), int(upper)
     return int(lower), int(upper)
-
-
-"""
-def blur(im, sigma=5):
-    from skimage import filters
-    return filters.gaussian(im, sigma=sigma, preserve_range=True)
-
-def open_close(im, kernel_size=3):
-
-    kernel = np.ones((kernel_size, kernel_size), dtype=np.uint8)
-
-    im = cv2.morphologyEx(im, cv2.MORPH_OPEN, kernel)
-    im = cv2.morphologyEx(im, cv2.MORPH_CLOSE, kernel)
-    return im
-"""

+ 0 - 1
pycs/frontend/endpoints/pipelines/PredictModel.py

@@ -12,7 +12,6 @@ from pycs.database.Result import Result
 from pycs.frontend.notifications.NotificationList import NotificationList
 from pycs.frontend.notifications.NotificationList import NotificationList
 from pycs.frontend.notifications.NotificationManager import NotificationManager
 from pycs.frontend.notifications.NotificationManager import NotificationManager
 from pycs.interfaces.MediaFile import MediaFile
 from pycs.interfaces.MediaFile import MediaFile
-from pycs.interfaces.MediaLabel import MediaLabel
 from pycs.interfaces.MediaBoundingBox import MediaBoundingBox
 from pycs.interfaces.MediaBoundingBox import MediaBoundingBox
 from pycs.interfaces.MediaStorage import MediaStorage
 from pycs.interfaces.MediaStorage import MediaStorage
 from pycs.jobs.JobGroupBusyException import JobGroupBusyException
 from pycs.jobs.JobGroupBusyException import JobGroupBusyException

+ 0 - 2
pycs/frontend/endpoints/results/CopyResults.py

@@ -1,12 +1,10 @@
 from flask import abort
 from flask import abort
-from flask import jsonify
 from flask import make_response
 from flask import make_response
 from flask import request
 from flask import request
 from flask.views import View
 from flask.views import View
 
 
 from pycs import db
 from pycs import db
 from pycs.database.File import File
 from pycs.database.File import File
-from pycs.database.File import Result
 from pycs.frontend.notifications.NotificationManager import NotificationManager
 from pycs.frontend.notifications.NotificationManager import NotificationManager
 
 
 
 

+ 173 - 3
pycs/management/result.py

@@ -1,10 +1,11 @@
 import click
 import click
 import flask
 import flask
+import simplejson as json
 
 
 from flask.cli import AppGroup
 from flask.cli import AppGroup
 
 
 from pycs import app
 from pycs import app
-from pycs.database.Project import Project
+from pycs import database as db
 
 
 result_cli = AppGroup("result", short_help="Result operations")
 result_cli = AppGroup("result", short_help="Result operations")
 
 
@@ -17,13 +18,13 @@ result_cli = AppGroup("result", short_help="Result operations")
 def export(project_id, output, indent):
 def export(project_id, output, indent):
     """ Export results for a specific project or for all projects """
     """ Export results for a specific project or for all projects """
     if project_id == "all":
     if project_id == "all":
-        projects = Project.query.all()
+        projects = db.Project.query.all()
         app.logger.info(f"Exporting results for all projects ({len(projects)})!")
         app.logger.info(f"Exporting results for all projects ({len(projects)})!")
         if output is None:
         if output is None:
             output = "output.json"
             output = "output.json"
 
 
     else:
     else:
-        project = Project.query.get(project_id)
+        project = db.Project.query.get(project_id)
         if project is None:
         if project is None:
             app.logger.error(f"Could not find project with ID {project_id}!")
             app.logger.error(f"Could not find project with ID {project_id}!")
             return
             return
@@ -58,3 +59,172 @@ def export(project_id, output, indent):
 
 
     with open(output, "w", encoding="utf-8") as out_f:
     with open(output, "w", encoding="utf-8") as out_f:
         flask.json.dump(results, out_f, app=app, indent=indent)
         flask.json.dump(results, out_f, app=app, indent=indent)
+
+
+
+
@result_cli.command("restore")
@click.argument("infile")
@click.option("--dry-run", is_flag=True)
def restore(infile, dry_run):
    """ Restore user-made bounding-box results from a JSON export file.

    Compares each exported result against the database and creates,
    updates, or deletes result entries so the DB matches the export.
    With --dry-run, only reports what would change.
    """

    # match the encoding used by the export command
    with open(infile, encoding="utf-8") as f:
        results = json.load(f)

    for project_results in results:
        project = db.Project.get_or_404(project_results["project_id"])

        for file_results in project_results["files"]:
            file = db.File.get_or_404(file_results["id"])

            # sanity check: the exported entry must refer to the same file
            assert file.path == file_results["path"]

            # first check for new and changed results
            for _result in file_results["results"]:

                if not _is_data_valid(**_result):
                    continue

                result = get_result_or_none(file, **_result)

                user1 = _result["origin_user"]
                data1 = _result["data"]
                ref1 = (_result["label"] or {}).get("reference")

                if result is None:
                    # we have a new result entry
                    if not dry_run:
                        file.create_result(
                            result_type="bounding-box",
                            origin="user",
                            origin_user=user1,
                            label=ref1,
                            data=data1,
                            commit=True,
                        )
                    print(" | ".join([
                        f"Project #{project.id:< 6d}",
                        f"File #{file.id:< 6d} [{file.name:^30s}]",
                        "[New Result]",
                        f"User: {user1 or '':<10s}",
                        f"Data: {data1}, Label-Ref: {ref1}",
                    ]))
                    continue

                assert result.file_id == _result["file_id"]
                user0 = result.origin_user
                data0 = result.data
                ref0 = getattr(result.label, "reference", None)

                is_same_data = _check_data(data0, data1)

                if is_same_data and (ref0 == ref1 or ref1 is None):
                    # nothing to change
                    continue

                print(" | ".join([
                    f"Project #{project.id:< 6d}",
                    f"File #{file.id:< 6d} [{file.name:^30s}]",
                ]), end=" | ")

                if not is_same_data:
                    # data was updated
                    print(" | ".join([
                        "[Data updated]",
                        f"User: {user1 or '':<10s}",
                        f"Data: {data0} -> {data1}",
                    ]), end=" | ")
                    assert user1 is not None
                    if not dry_run:
                        result.origin_user = user1
                        result.data = data1

                if ref0 != ref1:
                    assert user1 is not None
                    if not dry_run:
                        result.origin_user = user1

                    if ref1 is None:
                        # label was deleted
                        print("[Label Deleted]")
                        if not dry_run:
                            result.label_id = None
                    else:
                        # label was updated
                        print(" | ".join([
                            "[Label updated]",
                            f"User: {user0 or '':<10s} -> {user1 or '':<10s}",
                            f"{ref0 or 'UNK':<6s} -> {ref1 or 'UNK':<6s}",
                        ]))
                        label = project.label_by_reference(ref1)
                        if not dry_run:
                            result.label_id = label.id
                else:
                    print()

                if not dry_run:
                    result.commit()

            # then check for deleted results: user-made bounding boxes in the
            # DB that have no matching entry in the export file
            for result in file.results.all():
                if result.origin != "user" or result.type != "bounding-box":
                    continue

                found = False
                for _result in file_results["results"]:
                    if not _is_data_valid(**_result):
                        continue

                    if _check_data(result.data, _result["data"]):
                        found = True
                        break

                if not found:
                    print(" | ".join([
                        f"Project #{project.id:< 6d}",
                        f"File #{file.id:< 6d} [{file.name:^30s}]",
                        "[Result deleted]",
                        f"{result.data}",
                        f"{result.label}",
                    ]))

                    if not dry_run:
                        result.delete()
+
+def _is_data_valid(*, data, type, origin, **kwargs):
+
+    wh = (None, None) if data is None else (data["w"], data["h"])
+
+    return (type != "labeled-image" and
+        origin == "user" and
+        0 not in wh)
+
+def _check_data(data0, data1):
+
+    if None in (data0, data1):
+        return data0 == data1 == None
+
+    for key in data0:
+        if data1.get(key) != data0.get(key):
+            return False
+
+    return True
+
def get_result_or_none(file: db.File, id: int, data: dict, **kwargs):
    """ Look up the DB result matching an exported result entry.

    First tries the primary key (*id*, scoped to *file*); if that fails,
    falls back to matching the data payload against any result of the file.
    Returns None when nothing matches.  (``id`` shadows the builtin, but the
    keyword name is fixed by the export format.)
    """

    result = db.Result.query.filter(
        db.Result.id == id, db.Result.file_id == file.id).one_or_none()

    if result is not None:
        return result

    # fallback: IDs may have shifted, but the data payload can still match
    for other_result in file.results.all():
        if _check_data(data, other_result.data):
            return other_result

    return None

+ 1 - 0
requirements.txt

@@ -9,6 +9,7 @@ flask-sqlalchemy
 sqlalchemy_serializer
 sqlalchemy_serializer
 flask-migrate
 flask-migrate
 flask-htpasswd
 flask-htpasswd
+itsdangerous~=2.0.1
 python-socketio
 python-socketio
 munch
 munch
 scikit-image
 scikit-image

+ 26 - 3
tests/client/pipeline_tests.py

@@ -179,10 +179,33 @@ class LabelProviderPipelineTests:
     def test_label_loading_multiple(self):
     def test_label_loading_multiple(self):
 
 
         for i in range(3):
         for i in range(3):
-            self.post(self.url, json=dict(execute=True))
-            self.wait_for_bg_jobs()
+            self.test_label_loading()
+
+    def test_multiple_loading_does_not_delete_existing_labels(self):
+        self.test_label_loading()
+
+        file = self.project.files.first()
+
+        def _check():
+            for res in file.results.all():
+                self.assertIsNotNone(res.label_id)
+
+        for label in self.project.labels:
+            file.create_result(
+                origin="user",
+                result_type="bounding-box",
+                label=label,
+                data=dict(x=0, y=0, w=0.2, h=0.3),
+            )
+
+        file.commit()
+
+        _check()
+
+        for i in range(3):
+            self.test_label_loading()
+            _check()
 
 
-            self.assertEqual(self.n_labels, self.project.labels.count())
 
 
 class SimpleLabelProviderPipelineTests(LabelProviderPipelineTests, _BasePipelineTests):
 class SimpleLabelProviderPipelineTests(LabelProviderPipelineTests, _BasePipelineTests):