Browse Source

minor fixes in the tests

Dimitri Korsch 4 years ago
parent
commit
ea8356b91b
3 changed files with 27 additions and 15 deletions
  1. 5 4
      models/fixed_model/Pipeline.py
  2. 1 1
      pycs/frontend/endpoints/data/UploadFile.py
  3. 21 10
      test/test_client.py

+ 5 - 4
models/fixed_model/Pipeline.py

@@ -2,6 +2,7 @@ from json import dump, load
 from os import path
 from os import path
 from time import sleep
 from time import sleep
 
 
+from pycs import app
 from pycs.interfaces.MediaFile import MediaFile
 from pycs.interfaces.MediaFile import MediaFile
 from pycs.interfaces.MediaStorage import MediaStorage
 from pycs.interfaces.MediaStorage import MediaStorage
 from pycs.interfaces.Pipeline import Pipeline as Interface
 from pycs.interfaces.Pipeline import Pipeline as Interface
@@ -9,14 +10,14 @@ from pycs.interfaces.Pipeline import Pipeline as Interface
 
 
 class Pipeline(Interface):
 class Pipeline(Interface):
     def __init__(self, root_folder, distribution):
     def __init__(self, root_folder, distribution):
-        print('fmv1 init')
+        app.logger.debug('fmv1 init')
         self.root_folder = root_folder
         self.root_folder = root_folder
 
 
     def close(self):
     def close(self):
-        print('fmv1 close')
+        app.logger.debug('fmv1 close')
 
 
     def execute(self, storage: MediaStorage, file: MediaFile):
     def execute(self, storage: MediaStorage, file: MediaFile):
-        print('fmv1 execute')
+        app.logger.debug('fmv1 execute')
 
 
         data_file = path.join(self.root_folder, 'data.json')
         data_file = path.join(self.root_folder, 'data.json')
         if path.exists(data_file):
         if path.exists(data_file):
@@ -33,7 +34,7 @@ class Pipeline(Interface):
                     file.set_image_label(r['label'], r['frame'])
                     file.set_image_label(r['label'], r['frame'])
 
 
     def fit(self, storage: MediaStorage):
     def fit(self, storage: MediaStorage):
-        print('fmv1 fit')
+        app.logger.debug('fmv1 fit')
 
 
         for i in range(10):
         for i in range(10):
             yield i / 10
             yield i / 10

+ 1 - 1
pycs/frontend/endpoints/data/UploadFile.py

@@ -70,7 +70,7 @@ class UploadFile(View):
         # return default success response
         # return default success response
         return make_response()
         return make_response()
 
 
-    def custom_stream_factory(self, total_content_length, filename, content_type,
+    def custom_stream_factory(self, total_content_length, content_type, filename=None,
                               content_length=None):
                               content_length=None):
         """
         """
         save some useful information and open a file handler to save the uploaded file to
         save some useful information and open a file handler to save the uploaded file to

+ 21 - 10
test/test_client.py

@@ -1,16 +1,27 @@
+import eventlet
 import io
 import io
 import time
 import time
-import eventlet
 
 
 from test.base import BaseTestCase
 from test.base import BaseTestCase
 
 
 from pycs.database.File import File
 from pycs.database.File import File
-from pycs.database.Result import Result
 from pycs.database.Label import Label
 from pycs.database.Label import Label
+from pycs.database.LabelProvider import LabelProvider
+from pycs.database.Model import Model
 from pycs.database.Project import Project
 from pycs.database.Project import Project
+from pycs.database.Result import Result
 
 
 class ClientTests(BaseTestCase):
 class ClientTests(BaseTestCase):
 
 
+    def setUp(self):
+        super().setUp()
+        label_provider = LabelProvider.query.filter_by(name="Fixed Label Provider v1").one()
+        model = Model.query.filter_by(name="Fixed Base Model v1").one()
+
+        self.label_provider_id = label_provider.id
+        self.model_id = model.id
+
+
     def _do_request(self, request_func, *args,  status_code=200, **kwargs):
     def _do_request(self, request_func, *args,  status_code=200, **kwargs):
         response = request_func(*args, follow_redirects=True, **kwargs)
         response = request_func(*args, follow_redirects=True, **kwargs)
         self.assertEqual(response.status_code, status_code, response.get_data().decode())
         self.assertEqual(response.status_code, status_code, response.get_data().decode())
@@ -45,8 +56,8 @@ class ClientTests(BaseTestCase):
             json=dict(
             json=dict(
                 name="some name",
                 name="some name",
                 description="some description",
                 description="some description",
-                model=1,
-                label=2,
+                model=self.model_id,
+                label=self.label_provider_id,
                 external=None,
                 external=None,
             )
             )
         )
         )
@@ -67,8 +78,8 @@ class ClientTests(BaseTestCase):
             json=dict(
             json=dict(
                 name="some name",
                 name="some name",
                 description="some description",
                 description="some description",
-                model=1,
-                label=2,
+                model=self.model_id,
+                label=self.label_provider_id,
                 external=None,
                 external=None,
             )
             )
         )
         )
@@ -107,8 +118,8 @@ class ClientTests(BaseTestCase):
             json=dict(
             json=dict(
                 name="some name",
                 name="some name",
                 description="some description",
                 description="some description",
-                model=1,
-                label=2,
+                model=self.model_id,
+                label=self.label_provider_id,
                 external=None,
                 external=None,
             )
             )
         )
         )
@@ -156,8 +167,8 @@ class ClientTests(BaseTestCase):
             json=dict(
             json=dict(
                 name="some name",
                 name="some name",
                 description="some description",
                 description="some description",
-                model=1,
-                label=2,
+                model=self.model_id,
+                label=self.label_provider_id,
                 external=None,
                 external=None,
             )
             )
         )
         )