6
0
Kaynağa Gözat

added pipeline tests. fixed some bugs in the FitModel endpoint

Dimitri Korsch 4 yıl önce
ebeveyn
işleme
909eec39d6

+ 3 - 0
pycs/__init__.py

@@ -1,6 +1,9 @@
 import json
 import json
 import sys
 import sys
 import os
 import os
+# pylint: disable=wrong-import-position,wrong-import-order
+import eventlet.tpool
+eventlet.tpool.set_num_threads(2)
 
 
 from pathlib import Path
 from pathlib import Path
 from munch import munchify
 from munch import munchify

+ 36 - 34
pycs/frontend/WebServer.py

@@ -115,24 +115,24 @@ class WebServer:
         self.port = settings.port
         self.port = settings.port
 
 
         # create notification manager
         # create notification manager
-        jobs = JobRunner()
-        pipelines = PipelineCache(jobs)
-        notifications = NotificationManager(self.__sio)
+        self.jobs = JobRunner()
+        self.pipelines = PipelineCache(self.jobs)
+        self.notifications = NotificationManager(self.__sio)
 
 
-        jobs.on_create(notifications.create_job)
-        jobs.on_start(notifications.edit_job)
-        jobs.on_progress(notifications.edit_job)
-        jobs.on_finish(notifications.edit_job)
-        jobs.on_remove(notifications.remove_job)
+        self.jobs.on_create(self.notifications.create_job)
+        self.jobs.on_start(self.notifications.edit_job)
+        self.jobs.on_progress(self.notifications.edit_job)
+        self.jobs.on_finish(self.notifications.edit_job)
+        self.jobs.on_remove(self.notifications.remove_job)
 
 
-        self.define_routes(jobs, notifications, pipelines)
+        self.define_routes()
 
 
         if discovery:
         if discovery:
             Model.discover("models/")
             Model.discover("models/")
             LabelProvider.discover("labels/")
             LabelProvider.discover("labels/")
 
 
 
 
-    def define_routes(self, jobs, notifications, pipelines):
+    def define_routes(self):
         """ defines app routes """
         """ defines app routes """
 
 
         # additional
         # additional
@@ -144,11 +144,11 @@ class WebServer:
         # jobs
         # jobs
         self.app.add_url_rule(
         self.app.add_url_rule(
             '/jobs',
             '/jobs',
-            view_func=ListJobs.as_view('list_jobs', jobs)
+            view_func=ListJobs.as_view('list_jobs', self.jobs)
         )
         )
         self.app.add_url_rule(
         self.app.add_url_rule(
             '/jobs/<job_id>/remove',
             '/jobs/<job_id>/remove',
-            view_func=RemoveJob.as_view('remove_job', jobs)
+            view_func=RemoveJob.as_view('remove_job', self.jobs)
         )
         )
 
 
         # models
         # models
@@ -176,19 +176,19 @@ class WebServer:
         )
         )
         self.app.add_url_rule(
         self.app.add_url_rule(
             '/projects/<int:project_id>/labels',
             '/projects/<int:project_id>/labels',
-            view_func=CreateLabel.as_view('create_label', notifications)
+            view_func=CreateLabel.as_view('create_label', self.notifications)
         )
         )
         self.app.add_url_rule(
         self.app.add_url_rule(
             '/projects/<int:project_id>/labels/<int:label_id>/remove',
             '/projects/<int:project_id>/labels/<int:label_id>/remove',
-            view_func=RemoveLabel.as_view('remove_label', notifications)
+            view_func=RemoveLabel.as_view('remove_label', self.notifications)
         )
         )
         self.app.add_url_rule(
         self.app.add_url_rule(
             '/projects/<int:project_id>/labels/<int:label_id>/name',
             '/projects/<int:project_id>/labels/<int:label_id>/name',
-            view_func=EditLabelName.as_view('edit_label_name', notifications)
+            view_func=EditLabelName.as_view('edit_label_name', self.notifications)
         )
         )
         self.app.add_url_rule(
         self.app.add_url_rule(
             '/projects/<int:project_id>/labels/<int:label_id>/parent',
             '/projects/<int:project_id>/labels/<int:label_id>/parent',
-            view_func=EditLabelParent.as_view('edit_label_parent', notifications)
+            view_func=EditLabelParent.as_view('edit_label_parent', self.notifications)
         )
         )
 
 
         # collections
         # collections
@@ -204,7 +204,7 @@ class WebServer:
         # data
         # data
         self.app.add_url_rule(
         self.app.add_url_rule(
             '/projects/<int:project_id>/data',
             '/projects/<int:project_id>/data',
-            view_func=UploadFile.as_view('upload_file', notifications)
+            view_func=UploadFile.as_view('upload_file', self.notifications)
         )
         )
         self.app.add_url_rule(
         self.app.add_url_rule(
             '/projects/<int:project_id>/data',
             '/projects/<int:project_id>/data',
@@ -216,7 +216,7 @@ class WebServer:
         )
         )
         self.app.add_url_rule(
         self.app.add_url_rule(
             '/data/<int:file_id>/remove',
             '/data/<int:file_id>/remove',
-            view_func=RemoveFile.as_view('remove_file', notifications)
+            view_func=RemoveFile.as_view('remove_file', self.notifications)
         )
         )
         self.app.add_url_rule(
         self.app.add_url_rule(
             '/data/<int:file_id>',
             '/data/<int:file_id>',
@@ -246,28 +246,28 @@ class WebServer:
         )
         )
         self.app.add_url_rule(
         self.app.add_url_rule(
             '/data/<int:file_id>/results',
             '/data/<int:file_id>/results',
-            view_func=CreateResult.as_view('create_result', notifications)
+            view_func=CreateResult.as_view('create_result', self.notifications)
         )
         )
         self.app.add_url_rule(
         self.app.add_url_rule(
             '/data/<int:file_id>/reset',
             '/data/<int:file_id>/reset',
-            view_func=ResetResults.as_view('reset_results', notifications)
+            view_func=ResetResults.as_view('reset_results', self.notifications)
         )
         )
         self.app.add_url_rule(
         self.app.add_url_rule(
             '/results/<int:result_id>/remove',
             '/results/<int:result_id>/remove',
-            view_func=RemoveResult.as_view('remove_result', notifications)
+            view_func=RemoveResult.as_view('remove_result', self.notifications)
         )
         )
         self.app.add_url_rule(
         self.app.add_url_rule(
             '/results/<int:result_id>/confirm',
             '/results/<int:result_id>/confirm',
-            view_func=ConfirmResult.as_view('confirm_result', notifications)
+            view_func=ConfirmResult.as_view('confirm_result', self.notifications)
         )
         )
 
 
         self.app.add_url_rule(
         self.app.add_url_rule(
             '/results/<int:result_id>/label',
             '/results/<int:result_id>/label',
-            view_func=EditResultLabel.as_view('edit_result_label', notifications)
+            view_func=EditResultLabel.as_view('edit_result_label', self.notifications)
         )
         )
         self.app.add_url_rule(
         self.app.add_url_rule(
             '/results/<int:result_id>/data',
             '/results/<int:result_id>/data',
-            view_func=EditResultData.as_view('edit_result_data', notifications)
+            view_func=EditResultData.as_view('edit_result_data', self.notifications)
         )
         )
 
 
         # projects
         # projects
@@ -277,43 +277,45 @@ class WebServer:
         )
         )
         self.app.add_url_rule(
         self.app.add_url_rule(
             '/projects',
             '/projects',
-            view_func=CreateProject.as_view('create_project', notifications, jobs)
+            view_func=CreateProject.as_view('create_project', self.notifications, self.jobs)
         )
         )
         self.app.add_url_rule(
         self.app.add_url_rule(
             '/projects/<int:project_id>/label_provider',
             '/projects/<int:project_id>/label_provider',
-            view_func=ExecuteLabelProvider.as_view('execute_label_provider', notifications, jobs)
+            view_func=ExecuteLabelProvider.as_view('execute_label_provider',
+                                                   self.notifications, self.jobs)
         )
         )
         self.app.add_url_rule(
         self.app.add_url_rule(
             '/projects/<int:project_id>/external_storage',
             '/projects/<int:project_id>/external_storage',
             view_func=ExecuteExternalStorage.as_view('execute_external_storage',
             view_func=ExecuteExternalStorage.as_view('execute_external_storage',
-                                                     notifications, jobs)
+                                                     self.notifications, self.jobs)
         )
         )
         self.app.add_url_rule(
         self.app.add_url_rule(
             '/projects/<int:project_id>/remove',
             '/projects/<int:project_id>/remove',
-            view_func=RemoveProject.as_view('remove_project', notifications)
+            view_func=RemoveProject.as_view('remove_project', self.notifications)
         )
         )
         self.app.add_url_rule(
         self.app.add_url_rule(
             '/projects/<int:project_id>/name',
             '/projects/<int:project_id>/name',
-            view_func=EditProjectName.as_view('edit_project_name', notifications)
+            view_func=EditProjectName.as_view('edit_project_name', self.notifications)
         )
         )
         self.app.add_url_rule(
         self.app.add_url_rule(
             '/projects/<int:project_id>/description',
             '/projects/<int:project_id>/description',
-            view_func=EditProjectDescription.as_view('edit_project_description', notifications)
+            view_func=EditProjectDescription.as_view('edit_project_description', self.notifications)
         )
         )
 
 
         # pipelines
         # pipelines
         self.app.add_url_rule(
         self.app.add_url_rule(
             '/projects/<int:project_id>/pipelines/fit',
             '/projects/<int:project_id>/pipelines/fit',
-            view_func=FitModel.as_view('fit_model', jobs, pipelines)
+            view_func=FitModel.as_view('fit_model', self.jobs, self.pipelines)
         )
         )
         self.app.add_url_rule(
         self.app.add_url_rule(
             '/projects/<int:project_id>/pipelines/predict',
             '/projects/<int:project_id>/pipelines/predict',
-            view_func=PredictModel.as_view('predict_model', notifications, jobs,
-                                           pipelines)
+            view_func=PredictModel.as_view('predict_model', self.notifications, self.jobs,
+                                           self.pipelines)
         )
         )
         self.app.add_url_rule(
         self.app.add_url_rule(
             '/data/<int:file_id>/predict',
             '/data/<int:file_id>/predict',
-            view_func=PredictFile.as_view('predict_file', notifications, jobs, pipelines)
+            view_func=PredictFile.as_view('predict_file', self.notifications,
+                                          self.jobs, self.pipelines)
         )
         )
 
 
     def run(self):
     def run(self):

+ 6 - 4
pycs/frontend/endpoints/pipelines/FitModel.py

@@ -23,14 +23,14 @@ class FitModel(View):
         self.pipelines = pipelines
         self.pipelines = pipelines
 
 
     def dispatch_request(self, project_id):
     def dispatch_request(self, project_id):
+        project = Project.get_or_404(project_id)
+
         # extract request data
         # extract request data
         data = request.get_json(force=True)
         data = request.get_json(force=True)
 
 
         if not data.get('fit', False):
         if not data.get('fit', False):
             abort(400, "fit flag is missing")
             abort(400, "fit flag is missing")
 
 
-        # find project
-        project = Project.get_or_404(project_id)
 
 
         # create job
         # create job
         try:
         try:
@@ -38,7 +38,9 @@ class FitModel(View):
                           'Model Interaction',
                           'Model Interaction',
                           f'{project.name} (fit model with new data)',
                           f'{project.name} (fit model with new data)',
                           f'{project.name}/model-interaction',
                           f'{project.name}/model-interaction',
-                          FitModel.load_and_fit, project.id)
+                          FitModel.load_and_fit,
+                          self.pipelines,
+                          project.id)
 
 
         except JobGroupBusyException:
         except JobGroupBusyException:
             return abort(400, "Model fitting already running")
             return abort(400, "Model fitting already running")
@@ -62,7 +64,7 @@ class FitModel(View):
 
 
         # load pipeline
         # load pipeline
         try:
         try:
-            pipeline = pipelines.load_from_root_folder(project, model.root_folder)
+            pipeline = pipelines.load_from_root_folder(project_id, model.root_folder)
             yield from pipeline.fit(storage)
             yield from pipeline.fit(storage)
         except TypeError:
         except TypeError:
             pass
             pass

+ 5 - 4
pycs/frontend/endpoints/pipelines/PredictFile.py

@@ -26,15 +26,15 @@ class PredictFile(View):
         self.pipelines = pipelines
         self.pipelines = pipelines
 
 
     def dispatch_request(self, file_id):
     def dispatch_request(self, file_id):
+        # find file
+        file = File.get_or_404(file_id)
+
         # extract request data
         # extract request data
         data = request.get_json(force=True)
         data = request.get_json(force=True)
 
 
         if not data.get('predict', False):
         if not data.get('predict', False):
             abort(400, "predict flag is missing")
             abort(400, "predict flag is missing")
 
 
-        # find file
-        file = File.get_or_404(file_id)
-
         # get project and model
         # get project and model
         project = file.project
         project = file.project
 
 
@@ -47,7 +47,8 @@ class PredictFile(View):
                           f'{project.name} (create predictions)',
                           f'{project.name} (create predictions)',
                           f'{project.id}/model-interaction',
                           f'{project.id}/model-interaction',
                           Predict.load_and_predict,
                           Predict.load_and_predict,
-                          self.pipelines, notifications, project.id, [file.id],
+                          self.pipelines, notifications,
+                          project.id, [file.id],
                           progress=Predict.progress)
                           progress=Predict.progress)
 
 
         except JobGroupBusyException:
         except JobGroupBusyException:

+ 3 - 4
pycs/frontend/endpoints/pipelines/PredictModel.py

@@ -31,6 +31,8 @@ class PredictModel(View):
         self.pipelines = pipelines
         self.pipelines = pipelines
 
 
     def dispatch_request(self, project_id):
     def dispatch_request(self, project_id):
+        project = Project.get_or_404(project_id)
+
         # extract request data
         # extract request data
         data = request.get_json(force=True)
         data = request.get_json(force=True)
 
 
@@ -42,9 +44,6 @@ class PredictModel(View):
         if predict not in ['all', 'new']:
         if predict not in ['all', 'new']:
             abort(400, "predict must be either 'all' or 'new'")
             abort(400, "predict must be either 'all' or 'new'")
 
 
-        # find project
-        project = Project.get_or_404(project_id)
-
         # create job
         # create job
         try:
         try:
             notifications = NotificationList(self.nm)
             notifications = NotificationList(self.nm)
@@ -102,7 +101,7 @@ class PredictModel(View):
         media_files = map(lambda f: MediaFile(f, notifications), files)
         media_files = map(lambda f: MediaFile(f, notifications), files)
         # load pipeline
         # load pipeline
         try:
         try:
-            pipeline = pipelines.load_from_root_folder(project, model_root)
+            pipeline = pipelines.load_from_root_folder(project_id, model_root)
 
 
             # iterate over media files
             # iterate over media files
             index = 0
             index = 0

+ 26 - 5
pycs/util/PipelineCache.py

@@ -24,13 +24,24 @@ class PipelineCache:
         self.__queue = Queue()
         self.__queue = Queue()
         self.__lock = Lock()
         self.__lock = Lock()
 
 
+    def start(self):
+        """ starts the main worker method """
         spawn_n(self.__run)
         spawn_n(self.__run)
 
 
-    def load_from_root_folder(self, project: Project, root_folder: str) -> Pipeline:
+    @property
+    def is_empty(self):
+        """ checks whether the pipeline cache is empty """
+        return len(self.__pipelines) == 0
+
+    def shutdown(self):
+        """ puts None in the queue to signal the worker to stop """
+        self.__queue.put(None)
+
+    def load_from_root_folder(self, project_id: int, root_folder: str) -> Pipeline:
         """
         """
         load configuration.json and create an instance from the included code object
         load configuration.json and create an instance from the included code object
 
 
-        :param project: associated project
+        :param project_id: associated project ID
         :param root_folder: path to model root folder
         :param root_folder: path to model root folder
         :return: Pipeline instance
         :return: Pipeline instance
         """
         """
@@ -50,7 +61,7 @@ class PipelineCache:
 
 
         # save instance to cache
         # save instance to cache
         with self.__lock:
         with self.__lock:
-            self.__pipelines[root_folder] = [1, pipeline, project]
+            self.__pipelines[root_folder] = [1, pipeline, project_id]
 
 
         # return
         # return
         return pipeline
         return pipeline
@@ -75,7 +86,11 @@ class PipelineCache:
     def __get(self):
     def __get(self):
         while True:
         while True:
             # get element from queue
             # get element from queue
-            root_folder, timestamp = self.__queue.get()
+            entry = self.__queue.get()
+            if entry is None:
+                # closing pipeline cache
+                return None
+            root_folder, timestamp = entry
 
 
             # sleep if needed
             # sleep if needed
             delay = int(timestamp + self.CLOSE_TIMER - time())
             delay = int(timestamp + self.CLOSE_TIMER - time())
@@ -101,7 +116,13 @@ class PipelineCache:
     def __run(self):
     def __run(self):
         while True:
         while True:
             # get pipeline
             # get pipeline
-            pipeline, project = tpool.execute(self.__get)
+            result = tpool.execute(self.__get)
+            if result is None:
+                return
+
+            pipeline, project_id = result
+
+            project = Project.query.get(project_id)
 
 
             # create job to close pipeline
             # create job to close pipeline
             self.__jobs.run(project,
             self.__jobs.run(project,

+ 51 - 11
tests/base.py

@@ -12,6 +12,7 @@ from pycs import settings
 from pycs.frontend.WebServer import WebServer
 from pycs.frontend.WebServer import WebServer
 from pycs.database.Model import Model
 from pycs.database.Model import Model
 from pycs.database.LabelProvider import LabelProvider
 from pycs.database.LabelProvider import LabelProvider
+from pycs.util.PipelineCache import PipelineCache
 
 
 server = None
 server = None
 
 
@@ -27,36 +28,75 @@ def pаtch_tpool_execute(test_func):
 
 
 class BaseTestCase(unittest.TestCase):
 class BaseTestCase(unittest.TestCase):
     _sleep_time = 0.2
     _sleep_time = 0.2
+    server = None
 
 
-    def setUp(self, discovery: bool = False):
+    @classmethod
+    def setUpClass(cls, discovery: bool = False):
         global server
         global server
+        PipelineCache.CLOSE_TIMER = 2
         app.config["TESTING"] = True
         app.config["TESTING"] = True
-        self.projects_dir = app.config["TEST_PROJECTS_DIR"] = "test_projects"
         app.config["WTF_CSRF_ENABLED"] = False
         app.config["WTF_CSRF_ENABLED"] = False
         app.config["DEBUG"] = False
         app.config["DEBUG"] = False
         app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///"
         app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///"
 
 
+        if server is None:
+            server = WebServer(app, settings, discovery)
+
+        if cls.server is None:
+            cls.server = server
+        db.create_all()
+
+        # # run discovery modules manually
+        # Model.discover("models/")
+        # LabelProvider.discover("labels/")
+        cls.server.pipelines.start()
+
+
+    def wait_for_bg_jobs(self):
+
+        # wait for JobRunner jobs to finish
+        while True:
+            ready = True
+            for job in self.server.jobs.list():
+                if job.finished is None:
+                    print(f"{job} is not finished!")
+                    ready = False
+                    break
+
+            if ready:
+                break
+
+            eventlet.sleep(self._sleep_time)
+
+        # wait for PipelineCache to finish
+
+        while not self.server.pipelines.is_empty:
+            eventlet.sleep(self._sleep_time)
+
+    @classmethod
+    def tearDownClass(cls):
+        super().tearDownClass()
+        cls.server.pipelines.shutdown()
+
+
+
+    def setUp(self):
+        self.projects_dir = app.config["TEST_PROJECTS_DIR"] = "test_projects"
+
         db.create_all()
         db.create_all()
 
 
         self.client = app.test_client()
         self.client = app.test_client()
         self.context = app.test_request_context()
         self.context = app.test_request_context()
         self.context.push()
         self.context.push()
 
 
-        # init the server once
-        if server is None:
-            server = WebServer(app, settings, discovery)
-
-        elif discovery:
-            # run discovery modules manually
-            Model.discover("models/")
-            LabelProvider.discover("labels/")
-
         self.setupModels()
         self.setupModels()
 
 
     def setupModels(self):
     def setupModels(self):
         pass
         pass
 
 
     def tearDown(self):
     def tearDown(self):
+        self.wait_for_bg_jobs()
+
         self.context.pop()
         self.context.pop()
 
 
         if os.path.exists(self.projects_dir):
         if os.path.exists(self.projects_dir):

+ 2 - 1
tests/client/__init__.py

@@ -2,12 +2,13 @@ import tempfile
 
 
 from flask import url_for
 from flask import url_for
 
 
-from pycs.database.Model import Model
 from pycs.database.LabelProvider import LabelProvider
 from pycs.database.LabelProvider import LabelProvider
+from pycs.database.Model import Model
 
 
 from tests.base import BaseTestCase
 from tests.base import BaseTestCase
 from tests.client.file_tests import *
 from tests.client.file_tests import *
 from tests.client.label_tests import *
 from tests.client.label_tests import *
+from tests.client.pipeline_tests import *
 from tests.client.project_tests import *
 from tests.client.project_tests import *
 from tests.client.result_tests import *
 from tests.client.result_tests import *
 
 

+ 125 - 0
tests/client/pipeline_tests.py

@@ -0,0 +1,125 @@
+
+import eventlet
+import tempfile
+import uuid
+
+from flask import url_for
+from pathlib import Path
+
+from pycs.database.Model import Model
+from pycs.database.Project import Project
+
+from tests.base import BaseTestCase
+from tests.base import pаtch_tpool_execute
+
+
+
+class PipelineTests(BaseTestCase):
+
+    _sleep_time = .2
+
+    def setupModels(self):
+        super().setupModels()
+
+        Model.discover("tests/client/test_models")
+
+        self.model = Model.query.one()
+
+        self.project = Project.new(
+            name="test_project",
+            description="Project for a test case",
+            model=self.model,
+            root_folder="project_folder",
+            external_data=False,
+            data_folder="project_folder/data",
+        )
+        root = Path(self.project.root_folder)
+        data_root = Path(self.project.data_folder)
+
+        for folder in [data_root, root / "temp"]:
+            folder.mkdir(exist_ok=True, parents=True)
+
+        file_uuid = str(uuid.uuid1())
+        self.file, is_new = self.project.add_file(
+            uuid=file_uuid,
+            file_type="image",
+            name="name",
+            filename="image",
+            extension=".jpg",
+            size=32*1024,
+        )
+
+        self.assertTrue(is_new)
+        with open(self.file.absolute_path, "wb") as f:
+            f.write(b"some content")
+
+
+    def tearDown(self):
+        self.wait_for_bg_jobs()
+        self.project.delete()
+        super().tearDown()
+
+    def test_predict_file_busy(self):
+        url = url_for("predict_file", file_id=self.file.id)
+
+        self.post(url, json=dict(predict=True))
+        self.post(url, json=dict(predict=True), status_code=400)
+
+    def test_predict_file_errors(self):
+        self.post(url_for("predict_file", file_id=4242),
+            status_code=404)
+
+        url = url_for("predict_file", file_id=self.file.id)
+
+        for data in [None, dict(), dict(predict=False)]:
+            self.post(url, status_code=400, json=data)
+
+    def test_predict_file(self):
+        url = url_for("predict_file", file_id=self.file.id)
+        self.post(url, json=dict(predict=True))
+
+    def test_predict_model_errors(self):
+        self.post(url_for("predict_model", project_id=4242),
+            status_code=404)
+
+        url = url_for("predict_model", project_id=self.project.id)
+
+        for data in [None, dict(), dict(predict=False), dict(predict=True), dict(predict="not new or all")]:
+            self.post(url, status_code=400, json=data)
+
+    def test_predict_model_busy(self):
+        url = url_for("predict_model", project_id=self.project.id)
+        self.post(url, json=dict(predict="new"))
+        self.post(url, json=dict(predict="new"), status_code=400)
+
+
+    def test_predict_model_for_new(self):
+
+        url = url_for("predict_model", project_id=self.project.id)
+        self.post(url, json=dict(predict="new"))
+
+
+    def test_predict_model_for_all(self):
+        url = url_for("predict_model", project_id=self.project.id)
+        self.post(url, json=dict(predict="all"))
+
+
+
+    def test_model_fit_errors(self):
+
+        self.post(url_for("fit_model", project_id=4242),
+            status_code=404)
+
+        url = url_for("fit_model", project_id=self.project.id)
+
+        for data in [None, dict(), dict(fit=False)]:
+            self.post(url, status_code=400, json=data)
+
+    def test_model_fit_busy(self):
+        url = url_for("fit_model", project_id=self.project.id)
+        self.post(url, json=dict(fit=True))
+        self.post(url, json=dict(fit=True), status_code=400)
+
+    def test_model_fit(self):
+        url = url_for("fit_model", project_id=self.project.id)
+        self.post(url, json=dict(fit=True))

+ 11 - 0
tests/client/test_models/simple_model/configuration.json

@@ -0,0 +1,11 @@
+{
+  "name": "Test Model",
+  "description": "Simple test model",
+  "supports": [
+    "labeled-bounding-boxes"
+  ],
+  "code": {
+    "module": "model",
+    "class": "Model"
+  }
+}

+ 19 - 0
tests/client/test_models/simple_model/model.py

@@ -0,0 +1,19 @@
+
+from pycs.interfaces.Pipeline import Pipeline
+from pycs.interfaces.MediaFile import MediaFile
+from pycs.interfaces.MediaStorage import MediaStorage
+
+class Model(Pipeline):
+
+    def __init__(self, root_folder: str, configuration: dict):
+        super().__init__(root_folder, configuration)
+
+
+    def close(self):
+        print("Closing")
+
+    def execute(self, storage: MediaStorage, file: MediaFile):
+        print("executing model")
+
+    def fit(self, storage: MediaStorage):
+        print("fitting model")