Browse Source

added some client tests and reworked threaded execution

Dimitri Korsch 3 năm trước cách đây
mục cha
commit
b02d8c669a

+ 1 - 3
pycs/database/Database.py

@@ -18,14 +18,12 @@ class Database:
     opens an sqlite database and allows to access several objects
     """
 
-    def __init__(self, path: str = ':memory:', initialization=True, discovery=True):
+    def __init__(self, initialization=True, discovery=True):
         """
         opens or creates a given sqlite database and creates all required tables
 
         :param path: path to sqlite database
         """
-        # save properties
-        self.path = path
 
         if discovery:
             # run discovery modules

+ 14 - 3
pycs/frontend/WebServer.py

@@ -62,9 +62,7 @@ class WebServer:
         self.app = app
 
         # initialize database
-        db_file = settings["database"]
-        self.logger.info(f'Loading database from \"{db_file}\"')
-        self.db = Database(db_file)
+        self.db = Database()
 
         # start job runner
         self.logger.info('Starting job runner... ')
@@ -90,6 +88,19 @@ class WebServer:
         self.define_routes()
         self.logger.info("Server initialized")
 
+    def start_runner(self):
+        self.jobs.start()
+        self.pipelines.start()
+
+    def stop_runner(self):
+        self.jobs.stop()
+        self.pipelines.stop()
+
+    def wait_for_runner(self):
+        self.jobs.wait_for_empty_queue()
+        self.pipelines.wait_for_empty_queue()
+
+
     @property
     def logger(self):
         return self.app.logger

+ 29 - 23
pycs/frontend/endpoints/projects/CreateProject.py

@@ -3,10 +3,12 @@ from os import mkdir
 from os import path
 from shutil import copytree
 from uuid import uuid1
+from pathlib import Path
 
 from flask import make_response, request, abort
 from flask.views import View
 
+from pycs import app
 from pycs.database.Database import Database
 from pycs.database.Project import Project
 from pycs.frontend.endpoints.projects.ExecuteExternalStorage import ExecuteExternalStorage
@@ -29,65 +31,69 @@ class CreateProject(View):
         self.nm = nm
         self.jobs = jobs
 
+    @property
+    def project_folder(self):
+        return app.config["TEST_PROJECTS_DIR"] if app.config["TESTING"] else 'projects'
+
     def dispatch_request(self):
         # extract request data
         data = request.get_json(force=True)
 
-        if 'name' not in data or 'description' not in data:
+        if None in [data.get('name'), data.get('description')]:
             return abort(400, "name and description information missing!")
 
         name = data['name']
         description = data['description']
+        model_id = data['model']
+        label_provider_id = data['label']
+        data_folder = data['external']
+        external_data = data_folder is not None
 
         # start transaction
         with self.db:
             # find model
-            model_id = int(data['model'])
-            model = self.db.model(model_id)
+            model = self.db.model(int(model_id))
 
             if model is None:
                 return abort(404, "Model not found")
 
             # find label provider
-            if data['label'] is None:
+            if label_provider_id is None:
                 label_provider = None
             else:
-                label_provider_id = int(data['label'])
-                label_provider = self.db.label_provider(label_provider_id)
+                label_provider = self.db.label_provider(int(label_provider_id))
 
                 if label_provider is None:
                     return abort(404, "Label provider not found")
 
             # create project folder
-            project_folder = path.join('projects', str(uuid1()))
-            mkdir(project_folder)
+            project_folder = Path(self.project_folder, str(uuid1()))
+            project_folder.mkdir(parents=True)
 
-            temp_folder = path.join(project_folder, 'temp')
-            mkdir(temp_folder)
+            temp_folder = project_folder / 'temp'
+            temp_folder.mkdir()
 
             # check project data directory
-            if data['external'] is None:
-                external_data = False
-                data_folder = path.join(project_folder, 'data')
+            if external_data:
+                # check if exists
+                if not path.exists(data_folder):
+                    return abort(400, "Data folder does not exist!")
 
-                mkdir(data_folder)
             else:
-                external_data = True
-                data_folder = data['external']
+                data_folder = project_folder / 'data'
+                data_folder.mkdir()
 
-                # check if exists
-                if not path.exists(data_folder):
-                    return abort(400)
 
             # copy model to project folder
-            model_folder = path.join(project_folder, 'model')
-            copytree(model.root_folder, model_folder)
+            model_folder = project_folder / 'model'
+            copytree(model.root_folder, str(model_folder))
 
-            model, _ = model.copy_to(f'{model.name} ({name})', model_folder)
+            model, _ = model.copy_to(f'{model.name} ({name})', str(model_folder))
 
             # create entry in database
             project = self.db.create_project(name, description, model, label_provider,
-                                             project_folder, external_data, data_folder)
+                                             str(project_folder), external_data,
+                                             str(data_folder))
 
         # execute label provider and add labels to project
         if label_provider is not None:

+ 2 - 2
pycs/frontend/endpoints/results/CreateResult.py

@@ -31,14 +31,14 @@ class CreateResult(View):
         if 'label' in request_data and request_data['label']:
             label = request_data['label']
         elif request_data['type'] == 'labeled-image':
-            return abort(400)
+            return abort(400, "label missing for the labeled-image annotation")
         else:
             label = None
 
         if 'data' in request_data and request_data['data']:
             data = request_data['data']
         elif request_data['type'] == 'bounding-box':
-            return abort(400)
+            return abort(400, "data missing for the bounding box annotation")
         else:
             data = {}
 

+ 95 - 11
pycs/jobs/JobRunner.py

@@ -1,29 +1,31 @@
-from concurrent.futures import ThreadPoolExecutor
+# from concurrent.futures import ThreadPoolExecutor
 from time import time
 from types import GeneratorType
 from typing import Callable, List, Generator, Optional, Any
 
-from eventlet import spawn_n, tpool
+# import eventlet
+# from eventlet import spawn, spawn_n, tpool
 from eventlet.event import Event
-from eventlet.queue import Queue
+
 
 from pycs.database.Project import Project
 from pycs.jobs.Job import Job
 from pycs.jobs.JobGroupBusyException import JobGroupBusyException
 
+from pycs.util.green_worker import GreenWorker
 
-class JobRunner:
+class JobRunner(GreenWorker):
     """
     run jobs in a thread pool, but track progress and process results in eventlet queue
     """
 
     # pylint: disable=too-many-arguments
     def __init__(self):
+        super().__init__()
         self.__jobs = []
         self.__groups = {}
 
-        self.__executor = ThreadPoolExecutor(1)
-        self.__queue = Queue()
+        # self.__executor = ThreadPoolExecutor(1)
 
         self.__create_listeners = []
         self.__start_listeners = []
@@ -31,8 +33,6 @@ class JobRunner:
         self.__finish_listeners = []
         self.__remove_listeners = []
 
-        spawn_n(self.__run)
-
     def list(self) -> List[Job]:
         """
         get a list of all jobs including finished ones
@@ -150,13 +150,95 @@ class JobRunner:
             callback(job)
 
         # add to execution queue
-        self.__queue.put((group, executable, job, progress, result, result_event, args, kwargs))
+        # self.__queue.put((group, executable, job, progress, result, result_event, args, kwargs))
+        self.queue.put((group, executable, job, progress, result, result_event, args, kwargs))
 
         # return job object
         return job
 
+    def process_iterator(self, iterator, job, progress_fun):
+        try:
+            iterator = iter(iterator)
+
+            while True:
+                # run until next progress event
+                # future = self.__executor.submit(next, iterator)
+                # progress = tpool.execute(future.result)
+
+                # progress = future.result()
+                progress = next(iterator)
+
+                # execute progress function
+                if progress_fun is not None:
+                    if isinstance(progress, tuple):
+                        progress = progress_fun(*progress)
+                    else:
+                        progress = progress_fun(progress)
+
+                # execute progress listeners
+                job.progress = progress
+                job.updated = int(time())
+
+                for callback in self.__progress_listeners:
+                    callback(job)
+
+        except StopIteration as stop_iteration_exception:
+            return stop_iteration_exception.value
+
+
+    # done in a separate green thread
+    def work(self, group, executable, job, progress_fun, result_fun, result_event, args, kwargs):
+        # execute start listeners
+        job.started = int(time())
+        job.updated = int(time())
+
+        for callback in self.__start_listeners:
+            callback(job)
+
+        try:
+            result = generator = executable(*args, **kwargs)
+            if isinstance(generator, GeneratorType):
+                result = self.process_iterator(generator, job, progress_fun)
+
+            # update progress
+            job.progress = 1
+            job.updated = int(time())
+
+            for callback in self.__progress_listeners:
+                callback(job)
+
+            # execute result function
+            if result_fun is not None:
+                if isinstance(result, tuple):
+                    result_fun(*result)
+                else:
+                    result_fun(result)
+
+            # execute event
+            if result_event is not None:
+                result_event.send(result)
+
+        # save exceptions to show in ui
+        except Exception as e:
+            import traceback
+            traceback.print_exc()
+            job.exception = f'{type(e).__name__} ({str(e)})'
+
+        # remove from group dict
+        if group is not None:
+            del self.__groups[group]
+
+        # finish job
+        job.finished = int(time())
+        job.updated = int(time())
+
+        for callback in self.__finish_listeners:
+            callback(job)
+
     def __run(self):
+
         while True:
+
             # get execution function and job from queue
             group, executable, job, progress_fun, result_fun, result_event, args, kwargs \
                 = self.__queue.get(block=True)
@@ -170,9 +252,9 @@ class JobRunner:
 
             # run function and track progress
             try:
+                # result = generator = executable(*args, **kwargs)
                 future = self.__executor.submit(executable, *args, **kwargs)
-                generator = tpool.execute(future.result)
-                result = generator
+                result = generator = tpool.execute(future.result)
 
                 if isinstance(generator, GeneratorType):
                     iterator = iter(generator)
@@ -182,6 +264,8 @@ class JobRunner:
                             # run until next progress event
                             future = self.__executor.submit(next, iterator)
                             progress = tpool.execute(future.result)
+                            # progress = next(iterator)
+
 
                             # execute progress function
                             if progress_fun is not None:

+ 60 - 12
pycs/util/PipelineCache.py

@@ -1,52 +1,59 @@
+import eventlet
+
 from queue import Queue
 from threading import Lock
 from time import time, sleep
 
 from eventlet import tpool, spawn_n
+from collections import namedtuple
 
 from pycs.database.Project import Project
 from pycs.interfaces.Pipeline import Pipeline
 from pycs.jobs.JobRunner import JobRunner
 from pycs.util.PipelineUtil import load_from_root_folder
+from pycs.util.green_worker import GreenWorker
 
+PipelineEntry = namedtuple("PipelineEntry", "counter pipeline project_id")
 
-class PipelineCache:
+class PipelineCache(GreenWorker):
     CLOSE_TIMER = 120
 
     def __init__(self, jobs: JobRunner):
+        super().__init__()
         self.__jobs = jobs
 
-        self.__pipelines = {}
-        self.__queue = Queue()
+        self.__pipelines: dict[str, PipelineEntry] = {}
+        # self.__queue = Queue()
         self.__lock = Lock()
 
-        spawn_n(self.__run)
+        # self.__greenpool = eventlet.GreenPool()
+        # spawn_n(self.__run)
 
     def load_from_root_folder(self, project: Project, root_folder: str) -> Pipeline:
         """
         load configuration.json and create an instance from the included code object
 
-        :param project: associated project
+        :param project: associated project
         :param root_folder: path to model root folder
         :return: Pipeline instance
         """
         # check if instance is cached
         with self.__lock:
             if root_folder in self.__pipelines:
-                instance = self.__pipelines[root_folder]
+                entry: PipelineEntry = self.__pipelines[root_folder]
 
                 # increase reference counter
-                instance[0] += 1
+                entry.counter += 1
 
-                # return instance
-                return instance[1]
+                # return entry
+                return entry.pipeline
 
         # load pipeline
         pipeline = load_from_root_folder(root_folder)
 
         # save instance to cache
         with self.__lock:
-            self.__pipelines[root_folder] = [1, pipeline, project.id]
+            self.__pipelines[root_folder] = PipelineEntry(1, pipeline, project.id)
 
         # return
         return pipeline
@@ -66,7 +73,48 @@ class PipelineCache:
 
         # start timeout
         timestamp = time()
-        self.__queue.put((root_folder, timestamp))
+        self.queue.put((root_folder, timestamp))
+
+    # executed as coroutine in the main thread
+    def start_work(self, root_folder, timestamp):
+
+        # delegate to work method in a separate thread
+        pipeline, project_id = super().run(root_folder, timestamp)
+
+        project = Project.query.get(project_id)
+
+        # create job to close pipeline
+        self.__jobs.run(project,
+                        'Model Interaction',
+                        f'{project.name} (close pipeline)',
+                        f'{project.name}/model-interaction',
+                        pipeline.close
+                        )
+
+    # executed in a separate thread
+    def work(self, root_folder, timestamp):
+        while True:
+            # sleep if needed
+            delay = int(timestamp + self.CLOSE_TIMER - time())
+
+            if delay > 0:
+                eventlet.sleep(delay)
+
+            # lock and access __pipelines
+            with self.__lock:
+                instance: PipelineEntry = self.__pipelines[root_folder]
+
+                # reference counter greater than 1
+                if instance.counter > 1:
+                    # decrease reference counter (namedtuples are immutable, so rebuild the entry)
+                    self.__pipelines[root_folder] = instance._replace(counter=instance.counter - 1)
+                    continue
+
+                # reference counter equals 1
+                else:
+                    # delete instance from __pipelines and return to call `close` function
+                    del self.__pipelines[root_folder]
+                    return instance.pipeline, instance.project_id
 
     def __get(self):
         while True:
@@ -77,7 +125,7 @@ class PipelineCache:
             delay = int(timestamp + self.CLOSE_TIMER - time())
 
             if delay > 0:
-                sleep(delay)
+                eventlet.sleep(delay)
 
             # lock and access __pipelines
             with self.__lock:

+ 98 - 0
pycs/util/green_worker.py

@@ -0,0 +1,98 @@
+import abc
+import atexit
+import eventlet
+import threading
+
+# from concurrent.futures import ThreadPoolExecutor
+from eventlet import tpool
+
+class GreenWorker(abc.ABC):
+    def __init__(self):
+        super(GreenWorker, self).__init__()
+
+        self.pool = eventlet.GreenPool()
+        self.stop_event = eventlet.Event()
+        self.pool_finished = eventlet.Event()
+        self.queue = eventlet.Queue()
+        # self.executor = ThreadPoolExecutor()
+
+        self.__sleep_time = 0.1
+        self.__running = False
+
+
+
+    def start(self):
+        if self.__running:
+            return
+        # self._thread = self.pool.
+        eventlet.spawn(self.__run__)
+        self.__running = True
+
+    def stop(self):
+        if self.stop_event.has_result():
+            # do not re-send this event
+            return
+
+        # print(self, self.stop_event, "sending stop_event")
+        self.stop_event.send(True)
+
+        self.wait_for_empty_queue()
+
+        # self.pool.waitall()
+        pool_id = self.pool_finished.wait()
+        self.pool_finished.reset()
+        # print(f"pool_id #{pool_id} finished")
+        self.__running = False
+
+    def wait_for_empty_queue(self):
+        while not self.queue.empty():
+            eventlet.sleep(self.__sleep_time)
+            continue
+
+    def __run__(self):
+        while True:
+            if self.queue.empty():
+                # print("Queue was empty, checking for stop")
+                if self.stop_event.ready() and \
+                    self.stop_event.wait(self.__sleep_time):
+                    # print("Stop event received")
+                    self.stop_event.reset()
+                    break
+                else:
+                    eventlet.sleep(self.__sleep_time)
+                    # print("no stop event received")
+                    continue
+
+            args = self.queue.get(block=True)
+            self.start_work(*args)
+
+        # print(self.pool_finished)
+        # if not self.pool_finished.has_result():
+        self.pool_finished.send(threading.get_ident())
+
+    def start_work(self, *args):
+        return tpool.execute(self.work, *args)
+
+    @abc.abstractmethod
+    def work(self, *args):
+        pass
+
+
+
+if __name__ == '__main__':
+    import _thread as thread
+    class Foo(GreenWorker):
+        def work(self, value):
+            print(thread.get_ident(), value)
+
+
+    worker = Foo()
+    print("Main:", thread.get_ident())
+
+    worker.start()
+    worker.queue.put(("value1",))
+    worker.queue.put(("value2",))
+    worker.queue.put(("value3",))
+
+    # eventlet.sleep(.01)
+    worker.stop()

+ 2 - 0
test/__init__.py

@@ -0,0 +1,2 @@
+from test.test_client import ClientTests
+from test.test_database import DatabaseTests

+ 48 - 0
test/base.py

@@ -0,0 +1,48 @@
+
+import os
+import shutil
+import unittest
+
+from pycs import app
+from pycs import db
+from pycs import settings
+from pycs.frontend.WebServer import WebServer
+from pycs.database.Database import Database
+
+server = None
+
+class BaseTestCase(unittest.TestCase):
+    def setUp(self, discovery: bool = True):
+        global server
+        app.config["TESTING"] = True
+        self.projects_dir = app.config["TEST_PROJECTS_DIR"] = "test_projects"
+        app.config["WTF_CSRF_ENABLED"] = False
+        app.config["DEBUG"] = False
+        app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///"
+
+        db.create_all()
+
+        self.client = app.test_client()
+
+        if server is None:
+            server = WebServer(app, settings)
+
+        server.start_runner()
+
+        # create database
+        self.database = Database(discovery=discovery)
+
+    def tearDown(self):
+        global server
+
+        server.stop_runner()
+
+        if os.path.exists(self.projects_dir):
+            shutil.rmtree(self.projects_dir)
+
+        db.drop_all()
+
+    def wait_for_coroutines(self):
+        server.wait_for_runner()
+
+

+ 143 - 0
test/test_client.py

@@ -0,0 +1,143 @@
+import io
+import time
+import eventlet
+
+from test.base import BaseTestCase
+
+from pycs.database.File import File
+from pycs.database.Result import Result
+from pycs.database.Label import Label
+from pycs.database.Project import Project
+
+class ClientTests(BaseTestCase):
+
+    def _post(self, url, status_code=200, content_type=None, json=None, data=None):
+        response = self.client.post(url,
+            json=json,
+            data=data,
+            follow_redirects=True,
+            content_type=content_type,
+        )
+
+        self.assertEqual(response.status_code, status_code, response.get_data().decode())
+        return response
+
+    def test_project_creation(self):
+
+        self.assertEqual(0, Project.query.count())
+        self.assertEqual(0, Label.query.count())
+
+        self._post(
+            "/projects",
+            json=dict(
+                name="some name",
+                description="some description",
+                model=1,
+                label=2,
+                external=None,
+            )
+        )
+        self.assertEqual(1, Project.query.count())
+
+        project = Project.query.first()
+
+        self.assertIsNotNone(project)
+        self.assertIsNotNone(project.model)
+        self.assertIsNotNone(project.label_provider)
+
+        self.wait_for_coroutines()
+        self.assertNotEqual(0, Label.query.count())
+
+    def test_adding_file_with_result(self):
+
+        self._post("/projects",
+            json=dict(
+                name="some name",
+                description="some description",
+                model=1,
+                label=2,
+                external=None,
+            )
+        )
+        self.assertEqual(1, Project.query.count())
+        project = Project.query.first()
+
+        self.wait_for_coroutines()
+
+        self.assertEqual(0, File.query.count())
+        self._post(f"/projects/{project.id}/data",
+            data=dict(file=(io.BytesIO(b"some content"), "image.jpg")),
+            content_type="multipart/form-data",
+        )
+
+        self.assertEqual(1, File.query.count())
+        file = File.query.first()
+
+        self.assertEqual(0, Result.query.count())
+        self._post(f"data/{file.id}/results",
+            json=dict(
+                type="bounding-box",
+                data=dict(x0=0, x1=0, y0=0, y1=0),
+                label=2,
+            )
+        )
+        self.assertEqual(1, Result.query.count())
+
+    def test_cascade_after_project_removal(self):
+
+        self.assertEqual(0, File.query.count())
+        self.assertEqual(0, Result.query.count())
+        self.assertEqual(0, Label.query.count())
+        self.assertEqual(0, Project.query.count())
+
+        self._post("/projects",
+            json=dict(
+                name="some name",
+                description="some description",
+                model=1,
+                label=2,
+                external=None,
+            )
+        )
+        project = Project.query.first()
+        project_id = project.id
+
+        self.wait_for_coroutines()
+        self._post(f"/projects/{project_id}/data",
+            data=dict(file=(io.BytesIO(b"some content"), "image.jpg")),
+            content_type="multipart/form-data",
+        )
+        file = File.query.first()
+        file_id = file.id
+
+        self.wait_for_coroutines()
+        self._post(f"data/{file_id}/results",
+            json=dict(
+                type="bounding-box",
+                data=dict(x0=0, x1=0, y0=0, y1=0),
+                label=2,
+            )
+        )
+
+
+        self.assertNotEqual(0, File.query.count())
+        self.assertNotEqual(0, Result.query.count())
+        self.assertNotEqual(0, Label.query.count())
+        self.assertNotEqual(0, Project.query.count())
+
+        self.wait_for_coroutines()
+        eventlet.sleep(2)
+        self._post(f"/projects/{project_id}/remove",
+            json=dict(remove=True),
+        )
+
+        self.assertEqual(0, Project.query.count())
+        self.assertEqual(0, Label.query.count())
+        self.assertEqual(0, File.query.count())
+        self.assertEqual(0, Result.query.count())
+
+
+
+
+
+

+ 37 - 12
test/test_database.py

@@ -1,19 +1,18 @@
 import unittest
-from contextlib import closing
 
-from pycs import db
 from pycs.database.Database import Database
 from pycs.database.File import File
 from pycs.database.Label import Label
+from pycs.database.Result import Result
 from pycs.database.Model import Model
 from pycs.database.LabelProvider import LabelProvider
 
+from test.base import BaseTestCase
+
+class DatabaseTests(BaseTestCase):
 
-class TestDatabase(unittest.TestCase):
     def setUp(self) -> None:
-        db.create_all()
-        # create database
-        self.database = Database(discovery=False)
+        super().setUp(discovery=False)
 
         # insert default models and label_providers
         with self.database:
@@ -51,10 +50,6 @@ class TestDatabase(unittest.TestCase):
                 data_folder=f'datadir{i}',
             )
 
-    def tearDown(self) -> None:
-        db.drop_all()
-        self.database.close()
-
     def test_models(self):
         models = list(self.database.models())
 
@@ -170,11 +165,41 @@ class TestDatabase(unittest.TestCase):
             self.assertIsNotNone(label)
 
         self.assertEqual(5, Label.query.filter_by(project_id=project.id).count())
-        with self.database:
-            project.remove(commit=False)
+
+        project.remove()
 
         self.assertIsNone(self.database.project(1))
         self.assertEqual(0, Label.query.count())
 
+
+    def test_no_results_after_file_deletion(self):
+
+        project = self.database.project(1)
+        self.assertIsNotNone(project)
+
+        file, is_new = project.add_file(
+            uuid="some_string",
+            name="some_name",
+            filename="some_filename",
+            file_type="image",
+            extension=".jpg",
+            size=42,
+        )
+
+        self.assertIsNotNone(file)
+
+        for i in range(5):
+            result = file.create_result(
+                origin="pipeline",
+                result_type="bounding_box",
+                label=None,
+            )
+
+        self.assertEqual(5, Result.query.count())
+
+        File.query.filter_by(id=file.id).delete()
+        self.assertEqual(0, Result.query.count())
+
+
 if __name__ == '__main__':
     unittest.main()