1
1
Эх сурвалжийг харах

Merge branch '126-clean-model-interface' into 'master'

Resolve "clean model interface"

Closes #126

See merge request troebs/pycs!110
Eric Tröbs 3 жил өмнө
parent
commit
774084dae7

+ 15 - 12
models/fixed_model/Pipeline.py

@@ -1,10 +1,9 @@
 from json import dump, load
 from os import path
 from time import sleep
-from typing import List
 
-from pycs.interfaces.AnnotatedMediaFile import AnnotatedMediaFile
 from pycs.interfaces.MediaFile import MediaFile
+from pycs.interfaces.MediaStorage import MediaStorage
 from pycs.interfaces.Pipeline import Pipeline as Interface
 
 
@@ -16,12 +15,10 @@ class Pipeline(Interface):
     def close(self):
         print('fmv1 close')
 
-    def execute(self, file: MediaFile) -> List[dict]:
+    def execute(self, storage: MediaStorage, file: MediaFile):
         print('fmv1 execute')
-        sleep(0.01)
 
         data_file = path.join(self.root_folder, 'data.json')
-
         if path.exists(data_file):
             with open(data_file, 'r') as f:
                 result = load(f)
@@ -29,17 +26,23 @@ class Pipeline(Interface):
             result = {}
 
         if file.path in result:
-            return result[file.path]
-        else:
-            return []
+            for r in result[file.path]:
+                if r['type'] == 'MediaBoundingBox':
+                    file.add_bounding_box(r['x'], r['y'], r['w'], r['h'], r['label'])
+                if r['type'] == 'MediaImageLabel':
+                    file.set_image_label(r['label'])
 
-    def fit(self, files: List[AnnotatedMediaFile]):
+    def fit(self, storage: MediaStorage):
         print('fmv1 fit')
-        sleep(5)
+
+        for i in range(10):
+            yield i / 10
+            sleep(1)
 
         result = {}
-        for f in files:
-            result[f.path] = f.results
+        for f in storage.files().iter():
+            result[f.path] = list(map(lambda r: dict(r.__dict__, **{'type': type(r).__name__}),
+                                      f.results()))
 
         data_file = path.join(self.root_folder, 'data.json')
         with open(data_file, 'w') as file:

+ 20 - 18
models/haarcascade_frontalface_default/Pipeline.py

@@ -5,6 +5,7 @@ from urllib.request import urlretrieve
 import cv2
 
 from pycs.interfaces.MediaFile import MediaFile
+from pycs.interfaces.MediaStorage import MediaStorage
 from pycs.interfaces.Pipeline import Pipeline as Interface
 
 
@@ -33,38 +34,39 @@ class Pipeline(Interface):
             self.create_collection('none', 'no face detected')
         ]
 
-    def execute(self, file: MediaFile) -> List[dict]:
+    def execute(self, storage: MediaStorage, file: MediaFile):
         print('hcffdv1 execute')
 
-        # load image and convert to grayscale
+        # load image, convert to grayscale, scale down
         image = cv2.imread(file.path)
-        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
 
+        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
         height, width = gray.shape
-        min_size = int(min(width, height) / 10)
+
+        scale_factor = min(2048 / width, 2048 / height, 1.0)
+        scale_height, scale_width = int(height * scale_factor), int(width * scale_factor)
+        scaled = cv2.resize(gray, (scale_width, scale_height))
 
         # detect faces
         faces = self.face_cascade.detectMultiScale(
-            gray,
+            scaled,
             scaleFactor=1.1,
             minNeighbors=5,
-            minSize=(min_size, min_size)
+            minSize=(192, 192)
         )
 
         # convert faces to result list
-        result = []
+        result = False
+
         for x, y, w, h in faces:
-            result.append(self.create_bounding_box_result(
-                x / width,
-                y / height,
-                w / width,
-                h / height
-            ))
+            file.add_bounding_box(x / scale_width,
+                                  y / scale_height,
+                                  w / scale_width,
+                                  h / scale_height)
+            result = True
 
         # set file collection
-        if len(result) > 0:
-            result.append(self.create_collection_result('face'))
+        if result:
+            file.set_collection('face')
         else:
-            result.append(self.create_collection_result('none'))
-
-        return result
+            file.set_collection('none')

+ 14 - 4
pycs/database/Collection.py

@@ -1,5 +1,5 @@
 from contextlib import closing
-from typing import List
+from typing import List, Iterator
 
 from pycs.database.File import File
 
@@ -51,7 +51,7 @@ class Collection:
                            (self.project_id, self.identifier))
             return cursor.fetchone()[0]
 
-    def files(self, offset=0, limit=-1) -> List[File]:
+    def files(self, offset: int = 0, limit: int = -1) -> List[File]:
         """
         get a list of files associated with this collection
 
@@ -59,6 +59,16 @@ class Collection:
         :param limit: file limit
         :return: list of files
         """
+        return list(self.files_iter(offset, limit))
+
+    def files_iter(self, offset: int = 0, limit: int = -1) -> Iterator[File]:
+        """
+        get an iterator of files associated with this collection
+
+        :param offset: file offset
+        :param limit: file limit
+        :return: iterator of files
+        """
         with closing(self.database.con.cursor()) as cursor:
             cursor.execute('''
                 SELECT * FROM files
@@ -66,7 +76,7 @@ class Collection:
                 ORDER BY id ASC LIMIT ? OFFSET ?
                 ''', (self.project_id, self.identifier, limit, offset))
 
-            return list(map(
+            return map(
                 lambda row: File(self.database, row),
                 cursor.fetchall()
-            ))
+            )

+ 110 - 104
pycs/database/Database.py

@@ -18,7 +18,7 @@ class Database:
     opens an sqlite database and allows to access several objects
     """
 
-    def __init__(self, path: str = ':memory:', discovery=True):
+    def __init__(self, path: str = ':memory:', initialization=True, discovery=True):
         """
         opens or creates a given sqlite database and creates all required tables
 
@@ -31,111 +31,110 @@ class Database:
         self.con = sqlite3.connect(path)
         self.con.execute("PRAGMA foreign_keys = ON")
 
-        # create tables
-        with closing(self.con.cursor()) as cursor:
-            cursor.execute('''
-                CREATE TABLE IF NOT EXISTS models (
-                    id          INTEGER PRIMARY KEY,
-                    name        TEXT                NOT NULL,
-                    description TEXT,
-                    root_folder TEXT                NOT NULL UNIQUE,
-                    supports    TEXT                NOT NULL
-                )
-            ''')
-            cursor.execute('''
-                CREATE TABLE IF NOT EXISTS label_providers (
-                    id          INTEGER PRIMARY KEY,
-                    name        TEXT                NOT NULL,
-                    description TEXT,
-                    root_folder TEXT                NOT NULL UNIQUE
-                )
-            ''')
-
-            cursor.execute('''
-                CREATE TABLE IF NOT EXISTS projects (
-                    id             INTEGER PRIMARY KEY,
-                    name           TEXT                NOT NULL,
-                    description    TEXT,
-                    created        INTEGER             NOT NULL,
-                    model          INTEGER,
-                    label_provider INTEGER,
-                    root_folder    TEXT                NOT NULL UNIQUE,
-                    external_data  BOOL                NOT NULL,
-                    data_folder    TEXT                NOT NULL,
-                    FOREIGN KEY (model) REFERENCES models(id)
-                        ON UPDATE CASCADE ON DELETE SET NULL,
-                    FOREIGN KEY (label_provider) REFERENCES label_providers(id)
-                        ON UPDATE CASCADE ON DELETE SET NULL
-                )
-            ''')
-            cursor.execute('''
-                CREATE TABLE IF NOT EXISTS labels (
-                    id        INTEGER PRIMARY KEY,
-                    project   INTEGER             NOT NULL,
-                    parent    INTEGER,
-                    created   INTEGER             NOT NULL,
-                    reference TEXT,
-                    name      TEXT                NOT NULL,
-                    FOREIGN KEY (project) REFERENCES projects(id)
-                        ON UPDATE CASCADE ON DELETE CASCADE,
-                    FOREIGN KEY (parent) REFERENCES labels(id)
-                        ON UPDATE CASCADE ON DELETE SET NULL,
-                    UNIQUE(project, reference)
-                )
-            ''')
-            cursor.execute('''
-                CREATE TABLE IF NOT EXISTS collections (
-                    id          INTEGER          PRIMARY KEY,
-                    project     INTEGER NOT NULL,
-                    reference   TEXT    NOT NULL,
-                    name        TEXT    NOT NULL,
-                    description TEXT,
-                    position    INTEGER NOT NULL,
-                    autoselect  INTEGER NOT NULL,
-                    FOREIGN KEY (project) REFERENCES projects(id)
-                        ON UPDATE CASCADE ON DELETE CASCADE,
-                    UNIQUE(project, reference)
-                )
-            ''')
-            cursor.execute('''
-                CREATE TABLE IF NOT EXISTS files (
-                    id         INTEGER PRIMARY KEY,
-                    uuid       TEXT                NOT NULL,
-                    project    INTEGER             NOT NULL,
-                    collection INTEGER,
-                    type       TEXT                NOT NULL,
-                    name       TEXT                NOT NULL,
-                    extension  TEXT                NOT NULL,
-                    size       INTEGER             NOT NULL,
-                    created    INTEGER             NOT NULL,
-                    path       TEXT                NOT NULL,
-                    frames     INTEGER,
-                    fps        FLOAT,
-                    FOREIGN KEY (project) REFERENCES projects(id)
-                        ON UPDATE CASCADE ON DELETE CASCADE,
-                    FOREIGN KEY (collection) REFERENCES collections(id)
-                        ON UPDATE CASCADE ON DELETE SET NULL,
-                    UNIQUE(project, path)
-                )
-            ''')
-            cursor.execute('''
-                CREATE TABLE IF NOT EXISTS results (
-                    id     INTEGER PRIMARY KEY,
-                    file   INTEGER             NOT NULL,
-                    origin TEXT                NOT NULL,
-                    type   TEXT                NOT NULL,
-                    label  INTEGER,
-                    data   TEXT                NOT NULL,
-                    FOREIGN KEY (file) REFERENCES files(id)
-                        ON UPDATE CASCADE ON DELETE CASCADE
-                )
-            ''')
-            # cursor.execute('''
-            #     CREATE INDEX IF NOT EXISTS idx_results_label ON results(label)
-            # ''')
+        if initialization:
+            # create tables
+            with self:
+                with closing(self.con.cursor()) as cursor:
+                    cursor.execute('''
+                        CREATE TABLE IF NOT EXISTS models (
+                            id          INTEGER PRIMARY KEY,
+                            name        TEXT                NOT NULL,
+                            description TEXT,
+                            root_folder TEXT                NOT NULL UNIQUE,
+                            supports    TEXT                NOT NULL
+                        )
+                    ''')
+                    cursor.execute('''
+                        CREATE TABLE IF NOT EXISTS label_providers (
+                            id          INTEGER PRIMARY KEY,
+                            name        TEXT                NOT NULL,
+                            description TEXT,
+                            root_folder TEXT                NOT NULL UNIQUE
+                        )
+                    ''')
+
+                    cursor.execute('''
+                        CREATE TABLE IF NOT EXISTS projects (
+                            id             INTEGER PRIMARY KEY,
+                            name           TEXT                NOT NULL,
+                            description    TEXT,
+                            created        INTEGER             NOT NULL,
+                            model          INTEGER,
+                            label_provider INTEGER,
+                            root_folder    TEXT                NOT NULL UNIQUE,
+                            external_data  BOOL                NOT NULL,
+                            data_folder    TEXT                NOT NULL,
+                            FOREIGN KEY (model) REFERENCES models(id)
+                                ON UPDATE CASCADE ON DELETE SET NULL,
+                            FOREIGN KEY (label_provider) REFERENCES label_providers(id)
+                                ON UPDATE CASCADE ON DELETE SET NULL
+                        )
+                    ''')
+                    cursor.execute('''
+                        CREATE TABLE IF NOT EXISTS labels (
+                            id        INTEGER PRIMARY KEY,
+                            project   INTEGER             NOT NULL,
+                            parent    INTEGER,
+                            created   INTEGER             NOT NULL,
+                            reference TEXT,
+                            name      TEXT                NOT NULL,
+                            FOREIGN KEY (project) REFERENCES projects(id)
+                                ON UPDATE CASCADE ON DELETE CASCADE,
+                            FOREIGN KEY (parent) REFERENCES labels(id)
+                                ON UPDATE CASCADE ON DELETE SET NULL,
+                            UNIQUE(project, reference)
+                        )
+                    ''')
+                    cursor.execute('''
+                        CREATE TABLE IF NOT EXISTS collections (
+                            id          INTEGER          PRIMARY KEY,
+                            project     INTEGER NOT NULL,
+                            reference   TEXT    NOT NULL,
+                            name        TEXT    NOT NULL,
+                            description TEXT,
+                            position    INTEGER NOT NULL,
+                            autoselect  INTEGER NOT NULL,
+                            FOREIGN KEY (project) REFERENCES projects(id)
+                                ON UPDATE CASCADE ON DELETE CASCADE,
+                            UNIQUE(project, reference)
+                        )
+                    ''')
+                    cursor.execute('''
+                        CREATE TABLE IF NOT EXISTS files (
+                            id         INTEGER PRIMARY KEY,
+                            uuid       TEXT                NOT NULL,
+                            project    INTEGER             NOT NULL,
+                            collection INTEGER,
+                            type       TEXT                NOT NULL,
+                            name       TEXT                NOT NULL,
+                            extension  TEXT                NOT NULL,
+                            size       INTEGER             NOT NULL,
+                            created    INTEGER             NOT NULL,
+                            path       TEXT                NOT NULL,
+                            frames     INTEGER,
+                            fps        FLOAT,
+                            FOREIGN KEY (project) REFERENCES projects(id)
+                                ON UPDATE CASCADE ON DELETE CASCADE,
+                            FOREIGN KEY (collection) REFERENCES collections(id)
+                                ON UPDATE CASCADE ON DELETE SET NULL,
+                            UNIQUE(project, path)
+                        )
+                    ''')
+                    cursor.execute('''
+                        CREATE TABLE IF NOT EXISTS results (
+                            id     INTEGER PRIMARY KEY,
+                            file   INTEGER             NOT NULL,
+                            origin TEXT                NOT NULL,
+                            type   TEXT                NOT NULL,
+                            label  INTEGER,
+                            data   TEXT,
+                            FOREIGN KEY (file) REFERENCES files(id)
+                                ON UPDATE CASCADE ON DELETE CASCADE
+                        )
+                    ''')
 
-        # run discovery modules
         if discovery:
+            # run discovery modules
             with self:
                 discover_models(self.con)
                 discover_label_providers(self.con)
@@ -143,8 +142,15 @@ class Database:
     def close(self):
         self.con.close()
 
+    def copy(self):
+        return Database(self.path, initialization=False, discovery=False)
+
+    def commit(self):
+        self.con.commit()
+
     def __enter__(self):
         self.con.__enter__()
+        return self
 
     def __exit__(self, exc_type, exc_val, exc_tb):
         self.con.__exit__(exc_type, exc_val, exc_tb)

+ 16 - 2
pycs/database/File.py

@@ -57,7 +57,7 @@ class File:
                            (collection_id, self.identifier))
             self.collection_id = collection_id
 
-    def set_collection_by_reference(self, collection_reference: str):
+    def set_collection_by_reference(self, collection_reference: Optional[str]):
         """
         set this file's collection
 
@@ -204,7 +204,7 @@ class File:
 
             return None
 
-    def create_result(self, origin, result_type, label, data):
+    def create_result(self, origin, result_type, label, data=None):
         """
         create a result
 
@@ -224,3 +224,17 @@ class File:
             ''', (self.identifier, origin, result_type, label, data))
 
             return self.result(cursor.lastrowid)
+
+    def remove_results(self, origin='pipeline') -> List[Result]:
+        with closing(self.database.con.cursor()) as cursor:
+            cursor.execute('''
+                SELECT * FROM results WHERE file = ? AND origin = ?
+            ''', (self.identifier, origin))
+
+            results = list(map(lambda row: Result(self.database, row), cursor.fetchall()))
+
+            cursor.execute('''
+                DELETE FROM results WHERE file = ? AND origin = ?
+            ''', (self.identifier, origin))
+
+            return results

+ 46 - 4
pycs/database/Project.py

@@ -1,7 +1,7 @@
 from contextlib import closing
 from os.path import join
 from time import time
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Iterator
 
 from pycs.database.Collection import Collection
 from pycs.database.File import File
@@ -143,6 +143,23 @@ class Project:
 
             return None
 
+    def collection_by_reference(self, reference: str):
+        """
+        get a collection using its reference string
+
+        :param reference: reference string
+        :return: collection
+        """
+        with closing(self.database.con.cursor()) as cursor:
+            cursor.execute('SELECT * FROM collections WHERE reference = ? AND project = ?',
+                           (reference, self.identifier))
+            row = cursor.fetchone()
+
+            if row is not None:
+                return Collection(self.database, row)
+
+            return None
+
     def create_collection(self,
                           reference: str,
                           name: str,
@@ -216,7 +233,7 @@ class Project:
             cursor.execute('SELECT COUNT(*) FROM files WHERE project = ?', [self.identifier])
             return cursor.fetchone()[0]
 
-    def files(self, offset=0, limit=-1) -> List[File]:
+    def files(self, offset: int = 0, limit: int = -1) -> List[File]:
         """
         get a list of files associated with this project
 
@@ -224,14 +241,39 @@ class Project:
         :param limit: file limit
         :return: list of files
         """
+        return list(self.files_iter(offset, limit))
+
+    def files_iter(self, offset: int = 0, limit: int = -1) -> Iterator[File]:
+        """
+        get an iterator of files associated with this project
+
+        :param offset: file offset
+        :param limit: file limit
+        :return: iterator of files
+        """
         with closing(self.database.con.cursor()) as cursor:
             cursor.execute('SELECT * FROM files WHERE project = ? ORDER BY id ASC LIMIT ? OFFSET ?',
                            (self.identifier, limit, offset))
 
-            return list(map(
+            return map(
                 lambda row: File(self.database, row),
                 cursor.fetchall()
-            ))
+            )
+
+    def count_files_without_results(self) -> int:
+        """
+        count files without associated results
+
+        :return: count
+        """
+        with closing(self.database.con.cursor()) as cursor:
+            cursor.execute('''
+                SELECT COUNT(*)
+                FROM files
+                LEFT JOIN results ON files.id = results.file
+                WHERE files.project = ? AND results.id IS NULL
+            ''', [self.identifier])
+            return cursor.fetchone()[0]
 
     def files_without_results(self) -> List[File]:
         """

+ 1 - 1
pycs/database/Result.py

@@ -16,7 +16,7 @@ class Result:
         self.origin = row[2]
         self.type = row[3]
         self.label = row[4]
-        self.data = loads(row[5])
+        self.data = loads(row[5]) if row[5] is not None else None
 
     def remove(self):
         """

+ 26 - 13
pycs/frontend/endpoints/pipelines/FitModel.py

@@ -1,10 +1,8 @@
-from contextlib import closing
-
 from flask import make_response, request, abort
 from flask.views import View
 
 from pycs.database.Database import Database
-from pycs.interfaces.AnnotatedMediaFile import AnnotatedMediaFile
+from pycs.interfaces.MediaStorage import MediaStorage
 from pycs.jobs.JobGroupBusyException import JobGroupBusyException
 from pycs.jobs.JobRunner import JobRunner
 from pycs.util.PipelineUtil import load_from_root_folder as load_pipeline
@@ -34,24 +32,39 @@ class FitModel(View):
         if project is None:
             return abort(404)
 
-        # get model
-        model = project.model()
-
-        # get data and results
-        files = list(map(AnnotatedMediaFile, project.files()))
-
         # create job
         try:
             self.jobs.run(project,
                           'Model Interaction',
                           f'{project.name} (fit model with new data)',
                           f'{project.name}/model-interaction',
-                          self.load_and_fit, model, files)
+                          self.load_and_fit, self.db, project.identifier)
         except JobGroupBusyException:
             return abort(400)
 
         return make_response()
 
-    def load_and_fit(self, model, files):
-        with closing(load_pipeline(model.root_folder)) as pipeline:
-            pipeline.fit(files)
+    @staticmethod
+    def load_and_fit(database: Database, project_id: int):
+        db = None
+        pipeline = None
+
+        # create new database instance
+        try:
+            db = database.copy()
+            project = db.project(project_id)
+            model = project.model()
+            storage = MediaStorage(db, project_id)
+
+            # load pipeline
+            try:
+                pipeline = load_pipeline(model.root_folder)
+                yield from pipeline.fit(storage)
+            except TypeError:
+                pass
+            finally:
+                if pipeline is not None:
+                    pipeline.close()
+        finally:
+            if db is not None:
+                db.close()

+ 7 - 34
pycs/frontend/endpoints/pipelines/PredictFile.py

@@ -2,9 +2,9 @@ from flask import make_response, request, abort
 from flask.views import View
 
 from pycs.database.Database import Database
-from pycs.frontend.endpoints.pipelines.PredictModel import PredictModel
+from pycs.frontend.endpoints.pipelines.PredictModel import PredictModel as Predict
+from pycs.frontend.notifications.NotificationList import NotificationList
 from pycs.frontend.notifications.NotificationManager import NotificationManager
-from pycs.interfaces.MediaFile import MediaFile
 from pycs.jobs.JobGroupBusyException import JobGroupBusyException
 from pycs.jobs.JobRunner import JobRunner
 
@@ -34,47 +34,20 @@ class PredictFile(View):
         if file is None:
             return abort(404)
 
-        media_file = MediaFile(file)
-
         # get project and model
         project = file.project()
-        model = project.model()
 
         # create job
-        def store(index, length, result):
-            with self.db:
-                for remove in file.results():
-                    if remove.origin == 'pipeline':
-                        remove.remove()
-                        self.nm.remove_result(remove)
-
-                for entry in result:
-                    file_type = entry['type']
-                    del entry['type']
-
-                    if 'label' in entry:
-                        label = entry['label']
-                        del entry['label']
-                    else:
-                        label = None
-
-                    if file_type == 'labeled-image':
-                        for remove in file.results():
-                            remove.remove()
-                            self.nm.remove_result(remove)
-
-                    created = file.create_result('pipeline', file_type, label, entry)
-                    self.nm.create_result(created)
-
-            return (index + 1) / length
-
         try:
+            notifications = NotificationList(self.nm)
+
             self.jobs.run(project,
                           'Model Interaction',
                           f'{project.name} (create predictions)',
                           f'{project.name}/model-interaction',
-                          PredictModel.load_and_predict, model, [media_file],
-                          progress=store)
+                          Predict.load_and_predict,
+                          self.db, project.identifier, [file], notifications,
+                          progress=Predict.progress)
         except JobGroupBusyException:
             return abort(400)
 

+ 64 - 62
pycs/frontend/endpoints/pipelines/PredictModel.py

@@ -1,11 +1,13 @@
-from contextlib import closing
+from typing import Any
 
 from flask import make_response, request, abort
 from flask.views import View
 
 from pycs.database.Database import Database
+from pycs.frontend.notifications.NotificationList import NotificationList
 from pycs.frontend.notifications.NotificationManager import NotificationManager
 from pycs.interfaces.MediaFile import MediaFile
+from pycs.interfaces.MediaStorage import MediaStorage
 from pycs.jobs.JobGroupBusyException import JobGroupBusyException
 from pycs.jobs.JobRunner import JobRunner
 from pycs.util.PipelineUtil import load_from_root_folder as load_pipeline
@@ -36,76 +38,76 @@ class PredictModel(View):
         if project is None:
             return abort(404)
 
-        # get model
-        model = project.model()
-
-        # get data and results
-        if data['predict'] == 'new':
-            files = project.files_without_results()
-        else:
-            files = project.files()
-
-        objects = list(map(MediaFile, files))
-
         # create job
-        def store(index, length, result):
-            # get file from list
-            file = files[index]
-
-            # start transaction
-            with self.db:
-                # remove current results from file
-                for remove in file.results():
-                    if remove.origin == 'pipeline':
-                        remove.remove()
-                        self.nm.remove_result(remove)
-
-                # iterate over result entries
-                for entry in result:
-                    # extract entry type
-                    entry_type = entry['type']
-                    del entry['type']
-
-                    # update file collection
-                    if entry_type == 'collection':
-                        file.set_collection_by_reference(entry['reference'])
-                        self.nm.edit_file(file)
-                        continue
-
-                    # extract label from entry
-                    if 'label' in entry:
-                        label = entry['label']
-                        del entry['label']
-                    else:
-                        label = None
-
-                    # if entry_type == 'labeled-image':
-                    #     for remove in file.results():
-                    #         remove.remove()
-                    #         self.nm.remove_result(remove)
-
-                    # add result
-                    created = files[index].create_result('pipeline', entry_type, label, entry)
-                    self.nm.create_result(created)
-
-            return (index + 1) / length
-
         try:
+            notifications = NotificationList(self.nm)
+
             self.jobs.run(project,
                           'Model Interaction',
                           f'{project.name} (create predictions)',
                           f'{project.name}/model-interaction',
-                          self.load_and_predict, model, objects,
-                          progress=store)
+                          self.load_and_predict,
+                          self.db, project.identifier, data['predict'], notifications,
+                          progress=self.progress)
         except JobGroupBusyException:
             return abort(400)
 
         return make_response()
 
     @staticmethod
-    def load_and_predict(model, files):
-        with closing(load_pipeline(model.root_folder)) as pipeline:
-            length = len(files)
-            for index in range(length):
-                result = pipeline.execute(files[index])
-                yield index, length, result
+    def load_and_predict(database: Database, project_id: int, file_filter: Any,
+                         notifications: NotificationList):
+        db = None
+        pipeline = None
+
+        # create new database instance
+        try:
+            db = database.copy()
+            project = db.project(project_id)
+            model = project.model()
+            storage = MediaStorage(db, project_id, notifications)
+
+            # create a list of MediaFile
+            if isinstance(file_filter, str):
+                if file_filter == 'new':
+                    length = project.count_files_without_results()
+                    files = map(lambda f: MediaFile(f, notifications),
+                                project.files_without_results())
+                else:
+                    length = project.count_files()
+                    files = map(lambda f: MediaFile(f, notifications),
+                                project.files())
+            else:
+                files = map(lambda f: MediaFile(project.file(f.identifier), notifications),
+                            file_filter)
+                length = len(file_filter)
+
+            # load pipeline
+            try:
+                pipeline = load_pipeline(model.root_folder)
+
+                # iterate over files
+                index = 0
+                for file in files:
+                    # remove old predictions
+                    file.remove_predictions()
+
+                    # create new predictions
+                    pipeline.execute(storage, file)
+
+                    # commit changes and yield progress
+                    db.commit()
+                    yield index / length, notifications
+
+                    index += 1
+            finally:
+                if pipeline is not None:
+                    pipeline.close()
+        finally:
+            if db is not None:
+                db.close()
+
+    @staticmethod
+    def progress(progress: float, notifications: NotificationList):
+        notifications.fire()
+        return progress

+ 4 - 3
pycs/frontend/endpoints/results/GetProjectResults.py

@@ -2,7 +2,7 @@ from flask import abort, jsonify
 from flask.views import View
 
 from pycs.database.Database import Database
-from pycs.interfaces.AnnotatedMediaFile import AnnotatedMediaFile
+from pycs.interfaces.MediaStorage import MediaStorage
 
 
 class GetProjectResults(View):
@@ -22,8 +22,9 @@ class GetProjectResults(View):
         if project is None:
             return abort(404)
 
-        # get results
-        files = list(map(AnnotatedMediaFile, project.files()))
+        # map media files to a dict
+        ms = MediaStorage(self.db, project.identifier, None)
+        files = list(map(lambda f: f.serialize(), ms.files().iter()))
 
         # return result
         return jsonify(files)

+ 16 - 0
pycs/frontend/notifications/NotificationList.py

@@ -0,0 +1,16 @@
+from pycs.frontend.notifications.NotificationManager import NotificationManager
+
+
+class NotificationList:
+    def __init__(self, nm: NotificationManager):
+        self.__list = []
+        self.nm = nm
+
+    def add(self, fun: callable, *params):
+        self.__list.append((fun, *params))
+
+    def fire(self):
+        for fun, *params in self.__list:
+            fun(*params)
+
+        self.__list = []

+ 0 - 20
pycs/interfaces/AnnotatedMediaFile.py

@@ -1,20 +0,0 @@
-from pycs.database.File import File
-from pycs.interfaces.MediaFile import MediaFile
-
-
-class AnnotatedMediaFile(MediaFile):
-    # pylint: disable=too-few-public-methods
-    """
-    contains various attributes of a saved media file including annotations
-    """
-
-    def __init__(self, file: File):
-        super().__init__(file)
-
-        self.results = []
-        for result in file.results():
-            if result.origin == 'user':
-                self.results.append({**{
-                    'type': result.type,
-                    'label': result.label
-                }, **result.data})

+ 2 - 2
pycs/interfaces/LabelProvider.py

@@ -1,4 +1,4 @@
-import typing
+from typing import List
 
 
 class LabelProvider:
@@ -23,7 +23,7 @@ class LabelProvider:
         """
         raise NotImplementedError
 
-    def get_labels(self) -> typing.List[dict]:
+    def get_labels(self) -> List[dict]:
         """
         return all available labels
 

+ 13 - 0
pycs/interfaces/MediaBoundingBox.py

@@ -0,0 +1,13 @@
+from pycs.database.Result import Result
+
+
+class MediaBoundingBox:
+    def __init__(self, result: Result):
+        self.x = result.data['x']
+        self.y = result.data['y']
+        self.w = result.data['w']
+        self.h = result.data['h']
+        self.label = result.label
+
+    def serialize(self) -> dict:
+        return dict({'type': 'bounding-box'}, **self.__dict__)

+ 104 - 2
pycs/interfaces/MediaFile.py

@@ -1,15 +1,23 @@
 from os import path, getcwd
+from typing import Optional, List, Union
 
 from pycs.database.File import File
+from pycs.database.Result import Result
+from pycs.frontend.notifications.NotificationList import NotificationList
+from pycs.interfaces.MediaBoundingBox import MediaBoundingBox
+from pycs.interfaces.MediaImageLabel import MediaImageLabel
+from pycs.interfaces.MediaLabel import MediaLabel
 
 
 class MediaFile:
-    # pylint: disable=too-few-public-methods
     """
     contains various attributes of a saved media file
     """
 
-    def __init__(self, file: File):
+    def __init__(self, file: File, notifications: NotificationList):
+        self.__file = file
+        self.__notifications = notifications
+
         self.type = file.type
         self.size = file.size
         self.frames = file.frames
@@ -19,3 +27,97 @@ class MediaFile:
             self.path = file.path
         else:
             self.path = path.join(getcwd(), file.path)
+
+    def set_collection(self, reference: Optional[str]):
+        """
+        set this file's collection
+
+        :param reference: use None to remove this file's collection
+        """
+        self.__file.set_collection_by_reference(reference)
+        self.__notifications.add(self.__notifications.nm.edit_file, self.__file)
+
+    def set_image_label(self, label: Union[int, MediaLabel]):
+        """
+        create a labeled-image result
+
+        :param label: label identifier
+        """
+        if label is not None and isinstance(label, MediaLabel):
+            label = label.identifier
+
+        created = self.__file.create_result('pipeline', 'labeled-image', label)
+        self.__notifications.add(self.__notifications.nm.create_result, created)
+
+    def add_bounding_box(self, x: float, y: float, w: float, h: float,
+                         label: Union[int, MediaLabel] = None, frame: int = None):
+        """
+        create a bounding-box result
+
+        :param x: relative x coordinate [0, 1]
+        :param y: relative y coordinate [0, 1]
+        :param w: relative width [0, 1]
+        :param h: relative height [0, 1]
+        :param label: label
+        :param frame: frame index
+        """
+        result = {
+            'x': x,
+            'y': y,
+            'w': w,
+            'h': h
+        }
+        if frame is not None:
+            result['frame'] = frame
+
+        if label is not None and isinstance(label, MediaLabel):
+            label = label.identifier
+
+        created = self.__file.create_result('pipeline', 'bounding-box', label, result)
+        self.__notifications.add(self.__notifications.nm.create_result, created)
+
+    def remove_predictions(self):
+        """
+        remove and return all predictions added from pipelines
+        """
+        removed = self.__file.remove_results(origin='pipeline')
+        for r in removed:
+            self.__notifications.add(self.__notifications.nm.remove_result, r)
+
+    def __get_results(self, origin: str) -> List[Union[MediaImageLabel, MediaBoundingBox]]:
+        def map_r(result: Result) -> Union[MediaImageLabel, MediaBoundingBox]:
+            if result.type == 'labeled-image':
+                return MediaImageLabel(result)
+            else:
+                return MediaBoundingBox(result)
+
+        return list(map(map_r,
+                        filter(lambda r: r.origin == origin,
+                               self.__file.results())))
+
+    def results(self) -> List[Union[MediaImageLabel, MediaBoundingBox]]:
+        """
+        receive results added by users
+
+        :return: list of results
+        """
+        return self.__get_results('user')
+
+    def predictions(self) -> List[Union[MediaImageLabel, MediaBoundingBox]]:
+        """
+        receive results added by pipelines
+
+        :return: list of predictions
+        """
+        return self.__get_results('pipeline')
+
+    def serialize(self) -> dict:
+        return {
+            'type': self.type,
+            'size': self.size,
+            'frames': self.frames,
+            'fps': self.fps,
+            'path': self.path,
+            'results': list(map(lambda r: r.serialize(), self.results())),
+            'predictions': list(map(lambda r: r.serialize(), self.predictions())),
+        }

+ 57 - 0
pycs/interfaces/MediaFileList.py

@@ -0,0 +1,57 @@
+from typing import List, Iterator
+
+from pycs.database.Project import Project
+from pycs.frontend.notifications.NotificationList import NotificationList
+from pycs.interfaces.MediaFile import MediaFile
+from pycs.interfaces.MediaLabel import MediaLabel
+
+
+class MediaFileList:
+    """
+    helper class to filter and receive MediaFile elements
+    """
+
+    def __init__(self, project: Project, notifications: NotificationList):
+        self.__project = project
+        self.__notifications = notifications
+        self.__collection = None
+        self.__label = None
+
+    # TODO pydoc
+    def filter_collection(self, collection_reference: str):
+        self.__collection = collection_reference
+        return self
+
+    # TODO pydoc
+    def filter_label(self, label: MediaLabel):
+        self.__label = label
+        return self
+
+    def iter(self) -> Iterator[MediaFile]:
+        """
+        receive an iterator of files
+
+        :return: iterator of files
+        """
+        if self.__collection is not None:
+            source = self.__project.collection_by_reference(self.__collection)
+        else:
+            source = self.__project
+
+        if self.__label is None:
+            for file in source.files_iter():
+                yield MediaFile(file, self.__notifications)
+        else:
+            for file in source.files_iter():
+                for result in file.results():
+                    if result.label == self.__label:
+                        yield MediaFile(file, self.__notifications)
+                        break
+
+    def list(self) -> List[MediaFile]:
+        """
+        receive a list of files
+
+        :return: list of files
+        """
+        return list(self.iter())

+ 9 - 0
pycs/interfaces/MediaImageLabel.py

@@ -0,0 +1,9 @@
+from pycs.database.Result import Result
+
+
+class MediaImageLabel:
+    def __init__(self, result: Result):
+        self.label = result.label
+
+    def serialize(self) -> dict:
+        return dict({'type': 'image-label'}, **self.__dict__)

+ 10 - 0
pycs/interfaces/MediaLabel.py

@@ -0,0 +1,10 @@
+from pycs.database.Label import Label
+
+
+class MediaLabel:
+    def __init__(self, label: Label):
+        self.identifier = label.identifier
+        self.parent = None
+        self.children = []
+        self.reference = label.reference
+        self.name = label.name

+ 60 - 0
pycs/interfaces/MediaStorage.py

@@ -0,0 +1,60 @@
+from typing import List
+
+from pycs.database.Database import Database
+from pycs.frontend.notifications.NotificationList import NotificationList
+from pycs.interfaces.MediaFileList import MediaFileList
+from pycs.interfaces.MediaLabel import MediaLabel
+
+
+class MediaStorage:
+    """
+    helper class for pipelines to interact with database entities
+    """
+
+    def __init__(self, db: Database, project_id: int, notifications: NotificationList = None):
+        self.__db = db
+        self.__project_id = project_id
+        self.__notifications = notifications
+
+        self.__project = self.__db.project(self.__project_id)
+        self.__collections = self.__project.collections()
+
+    def labels(self) -> List[MediaLabel]:
+        """
+        receive a list of labels
+
+        :return: list of labels
+        """
+        label_list = self.__project.labels()
+        label_dict = {la.identifier: MediaLabel(la) for la in label_list}
+        result = []
+
+        for label in label_list:
+            ml = label_dict[label.identifier]
+
+            if label.parent_id is not None:
+                ml.parent = label_dict[label.parent_id]
+                ml.parent.children.append(ml)
+
+            result.append(ml)
+
+        return result
+
+    def labels_tree(self) -> List[MediaLabel]:
+        """
+        receive a tree of labels
+
+        :return: list of root-level labels (parent is None)
+        """
+        return list(filter(
+            lambda ml: ml.parent is None,
+            self.labels()
+        ))
+
+    def files(self) -> MediaFileList:
+        """
+        get a FileList object to filter and receive MediaFile elements
+
+        :return: FileList
+        """
+        return MediaFileList(self.__project, self.__notifications)

+ 6 - 64
pycs/interfaces/Pipeline.py

@@ -1,7 +1,7 @@
 from typing import List
 
-from pycs.interfaces.AnnotatedMediaFile import AnnotatedMediaFile
 from pycs.interfaces.MediaFile import MediaFile
+from pycs.interfaces.MediaStorage import MediaStorage
 
 
 class Pipeline:
@@ -57,77 +57,19 @@ class Pipeline:
             'autoselect': autoselect
         }
 
-    def execute(self, file: MediaFile) -> List[dict]:
+    def execute(self, storage: MediaStorage, file: MediaFile):
         """
-        receive a job, execute it and return the predicted result
+        receive a file, create predictions and add them to the object
 
+        :param storage: database abstraction object
         :param file: which should be analyzed
-        :return:
         """
         raise NotImplementedError
 
-    @staticmethod
-    def create_collection_result(reference: str) -> dict:
-        """
-        create a collection result dictionary
-
-        :param reference: use None to remove this file's collection
-        :return: dict
-        """
-        return {
-            'type': 'collection',
-            'reference': reference
-        }
-
-    @staticmethod
-    def create_labeled_image_result(label: int) -> dict:
-        """
-        create a labeled-image result dictionary
-
-        :param label: label identifier
-        :return: dict
-        """
-        return {
-            'type': 'labeled-image',
-            'label': label
-        }
-
-    @staticmethod
-    def create_bounding_box_result(x: float, y: float, w: float, h: float,
-                                   label=None, frame=None) -> dict:
-        # pylint: disable=too-many-arguments
-        # pylint: disable=invalid-name
-        """
-        create a bounding-box result dictionary
-
-        :param x: relative x coordinate [0, 1]
-        :param y: relative y coordinate [0, 1]
-        :param w: relative width [0, 1]
-        :param h: relative height [0, 1]
-        :param label: label identifier
-        :param frame: frame index
-        :return: dict
-        """
-        result = {
-            'type': 'bounding-box',
-            'x': x,
-            'y': y,
-            'w': w,
-            'h': h
-        }
-
-        if label is not None:
-            result['label'] = label
-        if frame is not None:
-            result['frame'] = frame
-
-        return result
-
-    def fit(self, files: List[AnnotatedMediaFile]):
+    def fit(self, storage: MediaStorage):
         """
         receive a list of annotated media files and adapt the underlying model
 
-        :param files: list of annotated media files
-        :return:
+        :param storage: database abstraction object
         """
         raise NotImplementedError

+ 16 - 3
pycs/jobs/JobRunner.py

@@ -2,7 +2,7 @@ from time import time
 from types import GeneratorType
 from typing import Callable, List, Generator, Optional, Any
 
-from eventlet import tpool, spawn_n
+from eventlet import spawn_n, GreenPool
 from eventlet.event import Event
 from eventlet.queue import Queue
 
@@ -152,7 +152,16 @@ class JobRunner:
         # return job object
         return job
 
+    @staticmethod
+    def __next(it):
+        try:
+            return next(it)
+        except StopIteration as e:
+            return e
+
     def __run(self):
+        pool = GreenPool(1)
+
         while True:
             # get execution function and job from queue
             group, executable, job, progress_fun, result_fun, result_event, args, kwargs \
@@ -166,7 +175,7 @@ class JobRunner:
                 callback(job)
 
             # run function and track progress
-            generator = tpool.execute(executable, *args, **kwargs)
+            generator = pool.spawn(executable, *args, **kwargs).wait()
             result = generator
 
             if isinstance(generator, GeneratorType):
@@ -175,7 +184,11 @@ class JobRunner:
                 try:
                     while True:
                         # run until next progress event
-                        progress = tpool.execute(next, iterator)
+                        progress = pool.spawn(self.__next, iterator).wait()
+
+                        # raise StopIteration if return is of this type
+                        if isinstance(progress, StopIteration):
+                            raise progress
 
                         # execute progress function
                         if progress_fun is not None:

+ 23 - 0
webui/src/components/media/paginated-media.vue

@@ -41,6 +41,7 @@ export default {
     window.addEventListener('wheel', this.scroll);
 
     this.$root.socket.on('create-file', this.change);
+    this.$root.socket.on('edit-file', this.edit);
     this.$root.socket.on('remove-file', this.change);
 
     this.get(() => this.$emit('click', this.images[0]));
@@ -50,6 +51,7 @@ export default {
     window.removeEventListener('wheel', this.scroll);
 
     this.$root.socket.off('create-file', this.change);
+    this.$root.socket.off('edit-file', this.edit);
     this.$root.socket.off('remove-file', this.change);
   },
   data: function () {
@@ -85,6 +87,27 @@ export default {
       if (file.project_id === this.$root.project.identifier)
         this.get();
     },
+    edit: function (file) {
+      // edited file is now in the current collection
+      // filter is stored as str, so '==' is intentional
+      if (file.collection_id == this.filter) {
+        this.get();
+        return;
+      }
+
+      // edited file is in the current image list
+      for (let image of this.images) {
+        if (image.identifier === file.identifier) {
+          this.get(() => {
+            // click the first image if the current shown was removed
+            if (this.current.identifier === file.identifier) {
+              this.$emit('click', this.images[0]);
+            }
+          });
+          return;
+        }
+      }
+    },
     deleteElement: function (element) {
       this.$root.socket.post(`/data/${element.identifier}/remove`, {remove: true});
     },