6
0
Jelajahi Sumber

added some client tests and reworked threaded execution

Dimitri Korsch 3 tahun lalu
induk
melakukan
b02d8c669a

+ 1 - 3
pycs/database/Database.py

@@ -18,14 +18,12 @@ class Database:
     opens an sqlite database and allows to access several objects
     opens an sqlite database and allows to access several objects
     """
     """
 
 
-    def __init__(self, path: str = ':memory:', initialization=True, discovery=True):
+    def __init__(self, initialization=True, discovery=True):
         """
         """
         opens or creates a given sqlite database and creates all required tables
         opens or creates a given sqlite database and creates all required tables
 
 
         :param path: path to sqlite database
         :param path: path to sqlite database
         """
         """
-        # save properties
-        self.path = path
 
 
         if discovery:
         if discovery:
             # run discovery modules
             # run discovery modules

+ 14 - 3
pycs/frontend/WebServer.py

@@ -62,9 +62,7 @@ class WebServer:
         self.app = app
         self.app = app
 
 
         # initialize database
         # initialize database
-        db_file = settings["database"]
-        self.logger.info(f'Loading database from \"{db_file}\"')
-        self.db = Database(db_file)
+        self.db = Database()
 
 
         # start job runner
         # start job runner
         self.logger.info('Starting job runner... ')
         self.logger.info('Starting job runner... ')
@@ -90,6 +88,19 @@ class WebServer:
         self.define_routes()
         self.define_routes()
         self.logger.info("Server initialized")
         self.logger.info("Server initialized")
 
 
+    def start_runner(self):
+        self.jobs.start()
+        self.pipelines.start()
+
+    def stop_runner(self):
+        self.jobs.stop()
+        self.pipelines.stop()
+
+    def wait_for_runner(self):
+        self.jobs.wait_for_empty_queue()
+        self.pipelines.wait_for_empty_queue()
+
+
     @property
     @property
     def logger(self):
     def logger(self):
         return self.app.logger
         return self.app.logger

+ 29 - 23
pycs/frontend/endpoints/projects/CreateProject.py

@@ -3,10 +3,12 @@ from os import mkdir
 from os import path
 from os import path
 from shutil import copytree
 from shutil import copytree
 from uuid import uuid1
 from uuid import uuid1
+from pathlib import Path
 
 
 from flask import make_response, request, abort
 from flask import make_response, request, abort
 from flask.views import View
 from flask.views import View
 
 
+from pycs import app
 from pycs.database.Database import Database
 from pycs.database.Database import Database
 from pycs.database.Project import Project
 from pycs.database.Project import Project
 from pycs.frontend.endpoints.projects.ExecuteExternalStorage import ExecuteExternalStorage
 from pycs.frontend.endpoints.projects.ExecuteExternalStorage import ExecuteExternalStorage
@@ -29,65 +31,69 @@ class CreateProject(View):
         self.nm = nm
         self.nm = nm
         self.jobs = jobs
         self.jobs = jobs
 
 
+    @property
+    def project_folder(self):
+        return app.config["TEST_PROJECTS_DIR"] if app.config["TESTING"] else 'projects'
+
     def dispatch_request(self):
     def dispatch_request(self):
         # extract request data
         # extract request data
         data = request.get_json(force=True)
         data = request.get_json(force=True)
 
 
-        if 'name' not in data or 'description' not in data:
+        if None in [data.get('name'), data.get('description')]:
             return abort(400, "name and description information missing!")
             return abort(400, "name and description information missing!")
 
 
         name = data['name']
         name = data['name']
         description = data['description']
         description = data['description']
+        model_id = data['model']
+        label_provider_id = data['label']
+        data_folder = data['external']
+        external_data = data_folder is not None
 
 
         # start transaction
         # start transaction
         with self.db:
         with self.db:
             # find model
             # find model
-            model_id = int(data['model'])
-            model = self.db.model(model_id)
+            model = self.db.model(int(model_id))
 
 
             if model is None:
             if model is None:
                 return abort(404, "Model not found")
                 return abort(404, "Model not found")
 
 
             # find label provider
             # find label provider
-            if data['label'] is None:
+            if label_provider_id is None:
                 label_provider = None
                 label_provider = None
             else:
             else:
-                label_provider_id = int(data['label'])
-                label_provider = self.db.label_provider(label_provider_id)
+                label_provider = self.db.label_provider(int(label_provider_id))
 
 
                 if label_provider is None:
                 if label_provider is None:
                     return abort(404, "Label provider not found")
                     return abort(404, "Label provider not found")
 
 
             # create project folder
             # create project folder
-            project_folder = path.join('projects', str(uuid1()))
-            mkdir(project_folder)
+            project_folder = Path(self.project_folder, str(uuid1()))
+            project_folder.mkdir(parents=True)
 
 
-            temp_folder = path.join(project_folder, 'temp')
-            mkdir(temp_folder)
+            temp_folder = project_folder / 'temp'
+            temp_folder.mkdir()
 
 
             # check project data directory
             # check project data directory
-            if data['external'] is None:
-                external_data = False
-                data_folder = path.join(project_folder, 'data')
+            if external_data:
+                # check if exists
+                if not path.exists(data_folder):
+                    return abort(400, "Data folder does not exist!")
 
 
-                mkdir(data_folder)
             else:
             else:
-                external_data = True
-                data_folder = data['external']
+                data_folder = project_folder / 'data'
+                data_folder.mkdir()
 
 
-                # check if exists
-                if not path.exists(data_folder):
-                    return abort(400)
 
 
             # copy model to project folder
             # copy model to project folder
-            model_folder = path.join(project_folder, 'model')
-            copytree(model.root_folder, model_folder)
+            model_folder = project_folder / 'model'
+            copytree(model.root_folder, str(model_folder))
 
 
-            model, _ = model.copy_to(f'{model.name} ({name})', model_folder)
+            model, _ = model.copy_to(f'{model.name} ({name})', str(model_folder))
 
 
             # create entry in database
             # create entry in database
             project = self.db.create_project(name, description, model, label_provider,
             project = self.db.create_project(name, description, model, label_provider,
-                                             project_folder, external_data, data_folder)
+                                             str(project_folder), external_data,
+                                             str(data_folder))
 
 
         # execute label provider and add labels to project
         # execute label provider and add labels to project
         if label_provider is not None:
         if label_provider is not None:

+ 2 - 2
pycs/frontend/endpoints/results/CreateResult.py

@@ -31,14 +31,14 @@ class CreateResult(View):
         if 'label' in request_data and request_data['label']:
         if 'label' in request_data and request_data['label']:
             label = request_data['label']
             label = request_data['label']
         elif request_data['type'] == 'labeled-image':
         elif request_data['type'] == 'labeled-image':
-            return abort(400)
+            return abort(400, "label missing for the labeled-image annotation")
         else:
         else:
             label = None
             label = None
 
 
         if 'data' in request_data and request_data['data']:
         if 'data' in request_data and request_data['data']:
             data = request_data['data']
             data = request_data['data']
         elif request_data['type'] == 'bounding-box':
         elif request_data['type'] == 'bounding-box':
-            return abort(400)
+            return abort(400, "data missing for the bounding box annotation")
         else:
         else:
             data = {}
             data = {}
 
 

+ 95 - 11
pycs/jobs/JobRunner.py

@@ -1,29 +1,31 @@
-from concurrent.futures import ThreadPoolExecutor
+# from concurrent.futures import ThreadPoolExecutor
 from time import time
 from time import time
 from types import GeneratorType
 from types import GeneratorType
 from typing import Callable, List, Generator, Optional, Any
 from typing import Callable, List, Generator, Optional, Any
 
 
-from eventlet import spawn_n, tpool
+# import eventlet
+# from eventlet import spawn, spawn_n, tpool
 from eventlet.event import Event
 from eventlet.event import Event
-from eventlet.queue import Queue
+
 
 
 from pycs.database.Project import Project
 from pycs.database.Project import Project
 from pycs.jobs.Job import Job
 from pycs.jobs.Job import Job
 from pycs.jobs.JobGroupBusyException import JobGroupBusyException
 from pycs.jobs.JobGroupBusyException import JobGroupBusyException
 
 
+from pycs.util.green_worker import GreenWorker
 
 
-class JobRunner:
+class JobRunner(GreenWorker):
     """
     """
     run jobs in a thread pool, but track progress and process results in eventlet queue
     run jobs in a thread pool, but track progress and process results in eventlet queue
     """
     """
 
 
     # pylint: disable=too-many-arguments
     # pylint: disable=too-many-arguments
     def __init__(self):
     def __init__(self):
+        super().__init__()
         self.__jobs = []
         self.__jobs = []
         self.__groups = {}
         self.__groups = {}
 
 
-        self.__executor = ThreadPoolExecutor(1)
-        self.__queue = Queue()
+        # self.__executor = ThreadPoolExecutor(1)
 
 
         self.__create_listeners = []
         self.__create_listeners = []
         self.__start_listeners = []
         self.__start_listeners = []
@@ -31,8 +33,6 @@ class JobRunner:
         self.__finish_listeners = []
         self.__finish_listeners = []
         self.__remove_listeners = []
         self.__remove_listeners = []
 
 
-        spawn_n(self.__run)
-
     def list(self) -> List[Job]:
     def list(self) -> List[Job]:
         """
         """
         get a list of all jobs including finished ones
         get a list of all jobs including finished ones
@@ -150,13 +150,95 @@ class JobRunner:
             callback(job)
             callback(job)
 
 
         # add to execution queue
         # add to execution queue
-        self.__queue.put((group, executable, job, progress, result, result_event, args, kwargs))
+        # self.__queue.put((group, executable, job, progress, result, result_event, args, kwargs))
+        self.queue.put((group, executable, job, progress, result, result_event, args, kwargs))
 
 
         # return job object
         # return job object
         return job
         return job
 
 
+    def process_iterator(self, iterator, job, progress_fun):
+        try:
+            iterator = iter(iterator)
+
+            while True:
+                # run until next progress event
+                # future = self.__executor.submit(next, iterator)
+                # progress = tpool.execute(future.result)
+
+                # progress = future.result()
+                progress = next(iterator)
+
+                # execute progress function
+                if progress_fun is not None:
+                    if isinstance(progress, tuple):
+                        progress = progress_fun(*progress)
+                    else:
+                        progress = progress_fun(progress)
+
+                # execute progress listeners
+                job.progress = progress
+                job.updated = int(time())
+
+                for callback in self.__progress_listeners:
+                    callback(job)
+
+        except StopIteration as stop_iteration_exception:
+            return stop_iteration_exception.value
+
+
+    # done in a separate green thread
+    def work(self, group, executable, job, progress_fun, result_fun, result_event, args, kwargs):
+        # execute start listeners
+        job.started = int(time())
+        job.updated = int(time())
+
+        for callback in self.__start_listeners:
+            callback(job)
+
+        try:
+            result = generator = executable(*args, **kwargs)
+            if isinstance(generator, GeneratorType):
+                result = self.process_iterator(generator, job, progress_fun)
+
+            # update progress
+            job.progress = 1
+            job.updated = int(time())
+
+            for callback in self.__progress_listeners:
+                callback(job)
+
+            # execute result function
+            if result_fun is not None:
+                if isinstance(result, tuple):
+                    result_fun(*result)
+                else:
+                    result_fun(result)
+
+            # execute event
+            if result_event is not None:
+                result_event.send(result)
+
+        # save exceptions to show in ui
+        except Exception as e:
+            import traceback
+            traceback.print_exc()
+            job.exception = f'{type(e).__name__} ({str(e)})'
+
+        # remove from group dict
+        if group is not None:
+            del self.__groups[group]
+
+        # finish job
+        job.finished = int(time())
+        job.updated = int(time())
+
+        for callback in self.__finish_listeners:
+            callback(job)
+
     def __run(self):
     def __run(self):
+
         while True:
         while True:
+
             # get execution function and job from queue
             # get execution function and job from queue
             group, executable, job, progress_fun, result_fun, result_event, args, kwargs \
             group, executable, job, progress_fun, result_fun, result_event, args, kwargs \
                 = self.__queue.get(block=True)
                 = self.__queue.get(block=True)
@@ -170,9 +252,9 @@ class JobRunner:
 
 
             # run function and track progress
             # run function and track progress
             try:
             try:
+                # result = generator = executable(*args, **kwargs)
                 future = self.__executor.submit(executable, *args, **kwargs)
                 future = self.__executor.submit(executable, *args, **kwargs)
-                generator = tpool.execute(future.result)
-                result = generator
+                result = generator = tpool.execute(future.result)
 
 
                 if isinstance(generator, GeneratorType):
                 if isinstance(generator, GeneratorType):
                     iterator = iter(generator)
                     iterator = iter(generator)
@@ -182,6 +264,8 @@ class JobRunner:
                             # run until next progress event
                             # run until next progress event
                             future = self.__executor.submit(next, iterator)
                             future = self.__executor.submit(next, iterator)
                             progress = tpool.execute(future.result)
                             progress = tpool.execute(future.result)
+                            # progress = next(iterator)
+
 
 
                             # execute progress function
                             # execute progress function
                             if progress_fun is not None:
                             if progress_fun is not None:

+ 60 - 12
pycs/util/PipelineCache.py

@@ -1,52 +1,59 @@
+import eventlet
+
 from queue import Queue
 from queue import Queue
 from threading import Lock
 from threading import Lock
 from time import time, sleep
 from time import time, sleep
 
 
 from eventlet import tpool, spawn_n
 from eventlet import tpool, spawn_n
+from collections import namedtuple
 
 
 from pycs.database.Project import Project
 from pycs.database.Project import Project
 from pycs.interfaces.Pipeline import Pipeline
 from pycs.interfaces.Pipeline import Pipeline
 from pycs.jobs.JobRunner import JobRunner
 from pycs.jobs.JobRunner import JobRunner
 from pycs.util.PipelineUtil import load_from_root_folder
 from pycs.util.PipelineUtil import load_from_root_folder
+from pycs.util.green_worker import GreenWorker
 
 
+PipelineEntry = namedtuple("PipelineEntry", "counter pipeline project_id")
 
 
-class PipelineCache:
+class PipelineCache(GreenWorker):
     CLOSE_TIMER = 120
     CLOSE_TIMER = 120
 
 
     def __init__(self, jobs: JobRunner):
     def __init__(self, jobs: JobRunner):
+        super().__init__()
         self.__jobs = jobs
         self.__jobs = jobs
 
 
-        self.__pipelines = {}
-        self.__queue = Queue()
+        self.__pipelines: dict[str, PipelineEntry] = {}
+        # self.__queue = Queue()
         self.__lock = Lock()
         self.__lock = Lock()
 
 
-        spawn_n(self.__run)
+        # self.__greenpool = eventlet.GreenPool()
+        # spawn_n(self.__run)
 
 
     def load_from_root_folder(self, project: Project, root_folder: str) -> Pipeline:
     def load_from_root_folder(self, project: Project, root_folder: str) -> Pipeline:
         """
         """
         load configuration.json and create an instance from the included code object
         load configuration.json and create an instance from the included code object
 
 
-        :param project: associated project
+        :param project: associated project
         :param root_folder: path to model root folder
         :param root_folder: path to model root folder
         :return: Pipeline instance
         :return: Pipeline instance
         """
         """
         # check if instance is cached
         # check if instance is cached
         with self.__lock:
         with self.__lock:
             if root_folder in self.__pipelines:
             if root_folder in self.__pipelines:
-                instance = self.__pipelines[root_folder]
+                entry: PipelineEntry = self.__pipelines[root_folder]
 
 
                 # increase reference counter
                 # increase reference counter
-                instance[0] += 1
+                # PipelineEntry is a namedtuple (immutable): rebuild it via _replace
+                entry = entry._replace(counter=entry.counter + 1)
+                self.__pipelines[root_folder] = entry
 
 
-                # return instance
-                return instance[1]
+                # return entry
+                return entry.pipeline
 
 
         # load pipeline
         # load pipeline
         pipeline = load_from_root_folder(root_folder)
         pipeline = load_from_root_folder(root_folder)
 
 
         # save instance to cache
         # save instance to cache
         with self.__lock:
         with self.__lock:
-            self.__pipelines[root_folder] = [1, pipeline, project.id]
+            self.__pipelines[root_folder] = PipelineEntry(1, pipeline, project.id)
 
 
         # return
         # return
         return pipeline
         return pipeline
@@ -66,7 +73,48 @@ class PipelineCache:
 
 
         # start timeout
         # start timeout
         timestamp = time()
         timestamp = time()
-        self.__queue.put((root_folder, timestamp))
+        self.queue.put((root_folder, timestamp))
+
+    # executed as coroutine in the main thread
+    def start_work(self, root_folder, timestamp):
+
+        # delegate to work method in a separate thread
+        pipeline, project_id = super().start_work(root_folder, timestamp)
+
+        project = Project.query.get(project_id)
+
+        # create job to close pipeline
+        self.__jobs.run(project,
+                        'Model Interaction',
+                        f'{project.name} (close pipeline)',
+                        f'{project.name}/model-interaction',
+                        pipeline.close
+                        )
+
+    # executed in a separate thread
+    def work(self, root_folder, timestamp):
+        while True:
+            # sleep if needed
+            delay = int(timestamp + self.CLOSE_TIMER - time())
+
+            if delay > 0:
+                eventlet.sleep(delay)
+
+            # lock and access __pipelines
+            with self.__lock:
+                instance: PipelineEntry = self.__pipelines[root_folder]
+
+                # reference counter greater than 1
+                if instance.counter > 1:
+                    # decrease reference counter
+                    # PipelineEntry is a namedtuple (immutable): store a decremented copy
+                    self.__pipelines[root_folder] = instance._replace(counter=instance.counter - 1)
+                    continue
+
+                # reference counter equals 1
+                else:
+                    # delete instance from __pipelines and return to call `close` function
+                    del self.__pipelines[root_folder]
+                    return instance.pipeline, instance.project_id
 
 
     def __get(self):
     def __get(self):
         while True:
         while True:
@@ -77,7 +125,7 @@ class PipelineCache:
             delay = int(timestamp + self.CLOSE_TIMER - time())
             delay = int(timestamp + self.CLOSE_TIMER - time())
 
 
             if delay > 0:
             if delay > 0:
-                sleep(delay)
+                eventlet.sleep(delay)
 
 
             # lock and access __pipelines
             # lock and access __pipelines
             with self.__lock:
             with self.__lock:

+ 98 - 0
pycs/util/green_worker.py

@@ -0,0 +1,98 @@
+import abc
+import atexit
+import eventlet
+import threading
+
+# from concurrent.futures import ThreadPoolExecutor
+from eventlet import tpool
+
+class GreenWorker(abc.ABC):
+    def __init__(self):
+        super(GreenWorker, self).__init__()
+
+        self.pool = eventlet.GreenPool()
+        self.stop_event = eventlet.Event()
+        self.pool_finished = eventlet.Event()
+        self.queue = eventlet.Queue()
+        # self.executor = ThreadPoolExecutor()
+
+        self.__sleep_time = 0.1
+        self.__running = False
+
+
+
+    def start(self):
+        if self.__running:
+            return
+        # self._thread = self.pool.
+        eventlet.spawn(self.__run__)
+        self.__running = True
+
+    def stop(self):
+        if self.stop_event.has_result():
+            # do not re-send this event
+            return
+
+        # print(self, self.stop_event, "sending stop_event")
+        self.stop_event.send(True)
+
+        self.wait_for_empty_queue()
+
+        # self.pool.waitall()
+        pool_id = self.pool_finished.wait()
+        self.pool_finished.reset()
+        # print(f"pool_id #{pool_id} finished")
+        self.__running = False
+
+    def wait_for_empty_queue(self):
+        while not self.queue.empty():
+            eventlet.sleep(self.__sleep_time)
+            continue
+
+    def __run__(self):
+        while True:
+            if self.queue.empty():
+                # print("Queue was empty, checking for stop")
+                if self.stop_event.ready() and \
+                    self.stop_event.wait(self.__sleep_time):
+                    # print("Stop event received")
+                    self.stop_event.reset()
+                    break
+                else:
+                    eventlet.sleep(self.__sleep_time)
+                    # print("no stop event received")
+                    continue
+
+            args = self.queue.get(block=True)
+            self.start_work(*args)
+
+        # print(self.pool_finished)
+        # if not self.pool_finished.has_result():
+        self.pool_finished.send(threading.get_ident())
+
+    def start_work(self, *args):
+        return tpool.execute(self.work, *args)
+
+    @abc.abstractmethod
+    def work(self, *args):
+        pass
+
+
+
+if __name__ == '__main__':
+    import _thread as thread
+    class Foo(GreenWorker):
+        def work(self, value):
+            print(thread.get_ident(), value)
+
+
+    worker = Foo()
+    print("Main:", thread.get_ident())
+
+    worker.start()
+    worker.queue.put(("value1",))
+    worker.queue.put(("value2",))
+    worker.queue.put(("value3",))
+
+    # eventlet.sleep(.01)
+    worker.stop()

+ 2 - 0
test/__init__.py

@@ -0,0 +1,2 @@
+from test.test_client import ClientTests
+from test.test_database import DatabaseTests

+ 48 - 0
test/base.py

@@ -0,0 +1,48 @@
+
+import os
+import shutil
+import unittest
+
+from pycs import app
+from pycs import db
+from pycs import settings
+from pycs.frontend.WebServer import WebServer
+from pycs.database.Database import Database
+
+server = None
+
+class BaseTestCase(unittest.TestCase):
+    def setUp(self, discovery: bool = True):
+        global server
+        app.config["TESTING"] = True
+        self.projects_dir = app.config["TEST_PROJECTS_DIR"] = "test_projects"
+        app.config["WTF_CSRF_ENABLED"] = False
+        app.config["DEBUG"] = False
+        app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///"
+
+        db.create_all()
+
+        self.client = app.test_client()
+
+        if server is None:
+            server = WebServer(app, settings)
+
+        server.start_runner()
+
+        # create database
+        self.database = Database(discovery=discovery)
+
+    def tearDown(self):
+        global server
+
+        server.stop_runner()
+
+        if os.path.exists(self.projects_dir):
+            shutil.rmtree(self.projects_dir)
+
+        db.drop_all()
+
+    def wait_for_coroutines(self):
+        server.wait_for_runner()
+
+

+ 143 - 0
test/test_client.py

@@ -0,0 +1,143 @@
+import io
+import time
+import eventlet
+
+from test.base import BaseTestCase
+
+from pycs.database.File import File
+from pycs.database.Result import Result
+from pycs.database.Label import Label
+from pycs.database.Project import Project
+
+class ClientTests(BaseTestCase):
+
+    def _post(self, url, status_code=200, content_type=None, json=None, data=None):
+        response = self.client.post(url,
+            json=json,
+            data=data,
+            follow_redirects=True,
+            content_type=content_type,
+        )
+
+        self.assertEqual(response.status_code, status_code, response.get_data().decode())
+        return response
+
+    def test_project_creation(self):
+
+        self.assertEqual(0, Project.query.count())
+        self.assertEqual(0, Label.query.count())
+
+        self._post(
+            "/projects",
+            json=dict(
+                name="some name",
+                description="some description",
+                model=1,
+                label=2,
+                external=None,
+            )
+        )
+        self.assertEqual(1, Project.query.count())
+
+        project = Project.query.first()
+
+        self.assertIsNotNone(project)
+        self.assertIsNotNone(project.model)
+        self.assertIsNotNone(project.label_provider)
+
+        self.wait_for_coroutines()
+        self.assertNotEqual(0, Label.query.count())
+
+    def test_adding_file_with_result(self):
+
+        self._post("/projects",
+            json=dict(
+                name="some name",
+                description="some description",
+                model=1,
+                label=2,
+                external=None,
+            )
+        )
+        self.assertEqual(1, Project.query.count())
+        project = Project.query.first()
+
+        self.wait_for_coroutines()
+
+        self.assertEqual(0, File.query.count())
+        self._post(f"/projects/{project.id}/data",
+            data=dict(file=(io.BytesIO(b"some content"), "image.jpg")),
+            content_type="multipart/form-data",
+        )
+
+        self.assertEqual(1, File.query.count())
+        file = File.query.first()
+
+        self.assertEqual(0, Result.query.count())
+        self._post(f"data/{file.id}/results",
+            json=dict(
+                type="bounding-box",
+                data=dict(x0=0, x1=0, y0=0, y1=0),
+                label=2,
+            )
+        )
+        self.assertEqual(1, Result.query.count())
+
+    def test_cascade_after_project_removal(self):
+
+        self.assertEqual(0, File.query.count())
+        self.assertEqual(0, Result.query.count())
+        self.assertEqual(0, Label.query.count())
+        self.assertEqual(0, Project.query.count())
+
+        self._post("/projects",
+            json=dict(
+                name="some name",
+                description="some description",
+                model=1,
+                label=2,
+                external=None,
+            )
+        )
+        project = Project.query.first()
+        project_id = project.id
+
+        self.wait_for_coroutines()
+        self._post(f"/projects/{project_id}/data",
+            data=dict(file=(io.BytesIO(b"some content"), "image.jpg")),
+            content_type="multipart/form-data",
+        )
+        file = File.query.first()
+        file_id = file.id
+
+        self.wait_for_coroutines()
+        self._post(f"data/{file_id}/results",
+            json=dict(
+                type="bounding-box",
+                data=dict(x0=0, x1=0, y0=0, y1=0),
+                label=2,
+            )
+        )
+
+
+        self.assertNotEqual(0, File.query.count())
+        self.assertNotEqual(0, Result.query.count())
+        self.assertNotEqual(0, Label.query.count())
+        self.assertNotEqual(0, Project.query.count())
+
+        self.wait_for_coroutines()
+        eventlet.sleep(2)
+        self._post(f"/projects/{project_id}/remove",
+            json=dict(remove=True),
+        )
+
+        self.assertEqual(0, Project.query.count())
+        self.assertEqual(0, Label.query.count())
+        self.assertEqual(0, File.query.count())
+        self.assertEqual(0, Result.query.count())
+
+
+
+
+
+

+ 37 - 12
test/test_database.py

@@ -1,19 +1,18 @@
 import unittest
 import unittest
-from contextlib import closing
 
 
-from pycs import db
 from pycs.database.Database import Database
 from pycs.database.Database import Database
 from pycs.database.File import File
 from pycs.database.File import File
 from pycs.database.Label import Label
 from pycs.database.Label import Label
+from pycs.database.Result import Result
 from pycs.database.Model import Model
 from pycs.database.Model import Model
 from pycs.database.LabelProvider import LabelProvider
 from pycs.database.LabelProvider import LabelProvider
 
 
+from test.base import BaseTestCase
+
+class DatabaseTests(BaseTestCase):
 
 
-class TestDatabase(unittest.TestCase):
     def setUp(self) -> None:
     def setUp(self) -> None:
-        db.create_all()
-        # create database
-        self.database = Database(discovery=False)
+        super().setUp(discovery=False)
 
 
         # insert default models and label_providers
         # insert default models and label_providers
         with self.database:
         with self.database:
@@ -51,10 +50,6 @@ class TestDatabase(unittest.TestCase):
                 data_folder=f'datadir{i}',
                 data_folder=f'datadir{i}',
             )
             )
 
 
-    def tearDown(self) -> None:
-        db.drop_all()
-        self.database.close()
-
     def test_models(self):
     def test_models(self):
         models = list(self.database.models())
         models = list(self.database.models())
 
 
@@ -170,11 +165,41 @@ class TestDatabase(unittest.TestCase):
             self.assertIsNotNone(label)
             self.assertIsNotNone(label)
 
 
         self.assertEqual(5, Label.query.filter_by(project_id=project.id).count())
         self.assertEqual(5, Label.query.filter_by(project_id=project.id).count())
-        with self.database:
-            project.remove(commit=False)
+
+        project.remove()
 
 
         self.assertIsNone(self.database.project(1))
         self.assertIsNone(self.database.project(1))
         self.assertEqual(0, Label.query.count())
         self.assertEqual(0, Label.query.count())
 
 
+
+    def test_no_results_after_file_deletion(self):
+
+        project = self.database.project(1)
+        self.assertIsNotNone(project)
+
+        file, is_new = project.add_file(
+            uuid="some_string",
+            name="some_name",
+            filename="some_filename",
+            file_type="image",
+            extension=".jpg",
+            size=42,
+        )
+
+        self.assertIsNotNone(file)
+
+        for i in range(5):
+            result = file.create_result(
+                origin="pipeline",
+                result_type="bounding_box",
+                label=None,
+            )
+
+        self.assertEqual(5, Result.query.count())
+
+        File.query.filter_by(id=file.id).delete()
+        self.assertEqual(0, Result.query.count())
+
+
 if __name__ == '__main__':
 if __name__ == '__main__':
     unittest.main()
     unittest.main()