Browse Source

minor fixes in the tests

Dimitri Korsch 4 years ago
parent
commit
ea8356b91b
3 changed files with 27 additions and 15 deletions
  1. 5 4
      models/fixed_model/Pipeline.py
  2. 1 1
      pycs/frontend/endpoints/data/UploadFile.py
  3. 21 10
      test/test_client.py

+ 5 - 4
models/fixed_model/Pipeline.py

@@ -2,6 +2,7 @@ from json import dump, load
 from os import path
 from time import sleep
 
+from pycs import app
 from pycs.interfaces.MediaFile import MediaFile
 from pycs.interfaces.MediaStorage import MediaStorage
 from pycs.interfaces.Pipeline import Pipeline as Interface
@@ -9,14 +10,14 @@ from pycs.interfaces.Pipeline import Pipeline as Interface
 
 class Pipeline(Interface):
     def __init__(self, root_folder, distribution):
-        print('fmv1 init')
+        app.logger.debug('fmv1 init')
         self.root_folder = root_folder
 
     def close(self):
-        print('fmv1 close')
+        app.logger.debug('fmv1 close')
 
     def execute(self, storage: MediaStorage, file: MediaFile):
-        print('fmv1 execute')
+        app.logger.debug('fmv1 execute')
 
         data_file = path.join(self.root_folder, 'data.json')
         if path.exists(data_file):
@@ -33,7 +34,7 @@ class Pipeline(Interface):
                     file.set_image_label(r['label'], r['frame'])
 
     def fit(self, storage: MediaStorage):
-        print('fmv1 fit')
+        app.logger.debug('fmv1 fit')
 
         for i in range(10):
             yield i / 10

+ 1 - 1
pycs/frontend/endpoints/data/UploadFile.py

@@ -70,7 +70,7 @@ class UploadFile(View):
         # return default success response
         return make_response()
 
-    def custom_stream_factory(self, total_content_length, filename, content_type,
+    def custom_stream_factory(self, total_content_length, content_type, filename=None,
                               content_length=None):
         """
         save some useful information and open a file handler to save the uploaded file to

+ 21 - 10
test/test_client.py

@@ -1,16 +1,27 @@
+import eventlet
 import io
 import time
-import eventlet
 
 from test.base import BaseTestCase
 
 from pycs.database.File import File
-from pycs.database.Result import Result
 from pycs.database.Label import Label
+from pycs.database.LabelProvider import LabelProvider
+from pycs.database.Model import Model
 from pycs.database.Project import Project
+from pycs.database.Result import Result
 
 class ClientTests(BaseTestCase):
 
+    def setUp(self):
+        super().setUp()
+        label_provider = LabelProvider.query.filter_by(name="Fixed Label Provider v1").one()
+        model = Model.query.filter_by(name="Fixed Base Model v1").one()
+
+        self.label_provider_id = label_provider.id
+        self.model_id = model.id
+
+
     def _do_request(self, request_func, *args,  status_code=200, **kwargs):
         response = request_func(*args, follow_redirects=True, **kwargs)
         self.assertEqual(response.status_code, status_code, response.get_data().decode())
@@ -45,8 +56,8 @@ class ClientTests(BaseTestCase):
             json=dict(
                 name="some name",
                 description="some description",
-                model=1,
-                label=2,
+                model=self.model_id,
+                label=self.label_provider_id,
                 external=None,
             )
         )
@@ -67,8 +78,8 @@ class ClientTests(BaseTestCase):
             json=dict(
                 name="some name",
                 description="some description",
-                model=1,
-                label=2,
+                model=self.model_id,
+                label=self.label_provider_id,
                 external=None,
             )
         )
@@ -107,8 +118,8 @@ class ClientTests(BaseTestCase):
             json=dict(
                 name="some name",
                 description="some description",
-                model=1,
-                label=2,
+                model=self.model_id,
+                label=self.label_provider_id,
                 external=None,
             )
         )
@@ -156,8 +167,8 @@ class ClientTests(BaseTestCase):
             json=dict(
                 name="some name",
                 description="some description",
-                model=1,
-                label=2,
+                model=self.model_id,
+                label=self.label_provider_id,
                 external=None,
             )
         )