瀏覽代碼

added tests for file upload and removal. fixed some bugs

Dimitri Korsch 3 年之前
父節點
當前提交
d272f7c26d

+ 1 - 1
migrations/versions/b03df3e31b8d_.py

@@ -83,7 +83,7 @@ def upgrade():
     sa.Column('uuid', sa.String(), nullable=False),
     sa.Column('extension', sa.String(), nullable=False),
     sa.Column('type', sa.String(), nullable=False),
-    sa.Column('size', sa.String(), nullable=False),
+    sa.Column('size', sa.Integer(), nullable=False),
     sa.Column('created', sa.DateTime(), nullable=False),
     sa.Column('path', sa.String(), nullable=False),
     sa.Column('frames', sa.Integer(), nullable=True),

+ 3 - 2
pycs/database/File.py

@@ -25,7 +25,7 @@ class File(NamedBaseModel):
 
     type = db.Column(db.String, nullable=False)
 
-    size = db.Column(db.String, nullable=False)
+    size = db.Column(db.Integer, nullable=False)
 
     created = db.Column(db.DateTime, default=datetime.utcnow,
         index=True, nullable=False)
@@ -94,13 +94,14 @@ class File(NamedBaseModel):
         """
 
         # pylint: disable=unexpected-keyword-arg
-        super().delete(commit=commit)
+        dump = super().delete(commit=commit)
 
         if commit:
             os.remove(self.path)
 
         # TODO: remove temp files
         warnings.warn("Temporary files may still exist!")
+        return dump
 
     @commit_on_return
     def set_collection(self, collection_id: T.Optional[int]):

+ 1 - 1
pycs/database/base.py

@@ -44,8 +44,8 @@ class BaseModel(db.Model, SerializerMixin):
 
         :return: serialized self
         """
-        db.session.delete(self)
         dump = self.serialize()
+        db.session.delete(self)
 
         return dump
 

+ 16 - 16
pycs/frontend/endpoints/data/UploadFile.py

@@ -23,10 +23,10 @@ class UploadFile(View):
         self.nm = nm
 
         self.data_folder = None
-        self.file_uuid = None
-        self.file_name = None
-        self.file_extension = None
-        self.file_size = None
+        self.uuid = None
+        self.name = None
+        self.extension = None
+        self.size = None
 
     def dispatch_request(self, project_id: int):
         # find project
@@ -38,7 +38,7 @@ class UploadFile(View):
 
         # get upload path and id
         self.data_folder = project.data_folder
-        self.file_uuid = str(uuid.uuid1())
+        self.uuid = str(uuid.uuid1())
 
         # parse upload data
         _, _, files = tpool.execute(formparser.parse_form_data,
@@ -53,18 +53,18 @@ class UploadFile(View):
         try:
             ftype, frames, fps = tpool.execute(file_info,
                                                self.data_folder,
-                                               self.file_uuid,
-                                               self.file_extension)
+                                               self.uuid,
+                                               self.extension)
         except ValueError as exception:
             return abort(400, str(exception))
 
         file, _ = project.add_file(
-            uuid=self.file_uuid,
+            uuid=self.uuid,
             file_type=ftype,
-            name=self.file_name,
-            extension=self.file_extension,
-            size=self.file_size,
-            filename=self.file_uuid,
+            name=self.name,
+            extension=self.extension,
+            size=self.size,
+            filename=self.uuid,
             frames=frames,
             fps=fps)
 
@@ -87,14 +87,14 @@ class UploadFile(View):
         """
         # pylint: disable=unused-argument
         # set relevant properties
-        self.file_name, self.file_extension = os.path.splitext(filename)
+        self.name, self.extension = os.path.splitext(filename)
 
         if content_length is not None and content_length > 0:
-            self.file_size = content_length
+            self.size = content_length
         else:
-            self.file_size = total_content_length
+            self.size = total_content_length
 
         # open file handler
-        file_path = os.path.join(self.data_folder, f'{self.file_uuid}{self.file_extension}')
+        file_path = os.path.join(self.data_folder, f'{self.uuid}{self.extension}')
         #pylint: disable=consider-using-with
         return open(file_path, 'wb')

+ 20 - 4
tests/base.py

@@ -1,9 +1,11 @@
-
 import os
 import shutil
 import unittest
+import eventlet
 import typing as T
 
+from unittest import mock
+
 from pycs import app
 from pycs import db
 from pycs import settings
@@ -13,7 +15,18 @@ from pycs.database.LabelProvider import LabelProvider
 
 server = None
 
+def pаtch_tpool_execute(test_func):
+
+    def call_func(func, *args, **kwargs):
+        return func(*args, **kwargs)
+
+    decorator = mock.patch("eventlet.tpool.execute",
+        side_effect=call_func)
+
+    return decorator(test_func)
+
 class BaseTestCase(unittest.TestCase):
+    _sleep_time = 0.2
 
     def setUp(self, discovery: bool = False):
         global server
@@ -61,9 +74,9 @@ class BaseTestCase(unittest.TestCase):
              url: str,
              *,
              status_code: int = 200,
-             content_type: T.Optional[str] = None,
              data: T.Optional[dict] = None,
-             json: T.Optional[dict] = None):
+             json: T.Optional[dict] = None,
+             **kwargs):
 
         return self._do_request(
             self.client.post,
@@ -71,7 +84,7 @@ class BaseTestCase(unittest.TestCase):
             status_code=status_code,
             json=json,
             data=data,
-            content_type=content_type,
+            **kwargs
         )
 
     def get(self,
@@ -88,3 +101,6 @@ class BaseTestCase(unittest.TestCase):
             json=json,
             data=data,
         )
+
+    def wait_for_coroutines(self):
+        eventlet.sleep(self._sleep_time)

+ 2 - 0
tests/client/__init__.py

@@ -3,8 +3,10 @@ import tempfile
 from flask import url_for
 
 from tests.base import BaseTestCase
+from tests.client.file_tests import *
 from tests.client.label_tests import *
 from tests.client.project_tests import *
+from tests.client.result_tests import *
 
 
 class FolderInformationTest(BaseTestCase):

+ 127 - 0
tests/client/file_tests.py

@@ -0,0 +1,127 @@
+import io
+import os
+import uuid
+
+from flask import url_for
+from pathlib import Path
+from pycs.database.File import File
+from tests.base import pаtch_tpool_execute
+from tests.client.label_tests import _BaseLabelTests
+
+
+class _BaseFileTests(_BaseLabelTests):
+
+    def setupModels(self):
+        super().setupModels()
+        root = Path(self.project.root_folder)
+        data_root = Path(self.project.data_folder)
+
+        for folder in [data_root, root / "temp"]:
+            folder.mkdir(exist_ok=True, parents=True)
+
+
+class FileCreationTests(_BaseFileTests):
+
+    @pаtch_tpool_execute
+    def test_file_upload_project_with_external_data(self, mocked_execute=None):
+
+        file_content = b"some content+1"
+        url = url_for("upload_file", project_id=self.project.id)
+
+        self.assertEqual(0, File.query.count())
+
+        self.project.external_data = True
+        self.project.commit()
+
+        self.post(url,
+            data=dict(file=(io.BytesIO(file_content), "image.jpg")),
+            content_type="multipart/form-data",
+            status_code=400,
+        )
+
+        self.assertEqual(0, File.query.count())
+
+    @pаtch_tpool_execute
+    def test_file_upload(self, mocked_execute=None):
+
+        url = url_for("upload_file", project_id=4242)
+        self.post(url, data=dict(), status_code=404)
+
+        file_content = b"some content+1"
+        url = url_for("upload_file", project_id=self.project.id)
+
+        self.assertEqual(0, File.query.count())
+
+        self.post(url, data=dict(),
+            status_code=400)
+        self.assertEqual(0, File.query.count())
+
+        self.post(url,
+            data=dict(file=(io.BytesIO(file_content), "image.jpg")),
+            content_type="multipart/form-data",
+        )
+
+        self.assertEqual(1, File.query.count())
+
+        # this does not work, if we do not set the CONTENT_LENGTH by ourself
+        # file = File.query.first()
+        # self.assertEqual(len(file_content), file.size)
+
+
+class FileDeletionTests(_BaseFileTests):
+
+    def test_file_removal(self):
+
+        file_uuid = str(uuid.uuid1())
+        file, is_new = self.project.add_file(
+            uuid=file_uuid,
+            file_type="image",
+            name=f"name",
+            filename=f"image",
+            extension=".jpg",
+            size=32*1024,
+        )
+
+        self.assertTrue(is_new)
+
+        self.assertEqual(1, self.project.files.count())
+
+        with open(file.absolute_path, "w"):
+            pass
+
+        self.assertTrue(os.path.exists(file.absolute_path))
+
+        url = url_for("remove_file", file_id=file.id)
+        self.post(url, json=dict(), status_code=400)
+        self.post(url, json=dict(remove=False), status_code=400)
+        self.post(url, json=dict(remove=True))
+        self.assertEqual(0, self.project.files.count())
+        self.assertFalse(os.path.exists(file.absolute_path))
+
+        url = url_for("remove_file", file_id=4242)
+        self.post(url, json=dict(remove=True), status_code=404)
+
+    def test_file_removal_from_project_with_external_data(self):
+
+        file_uuid = str(uuid.uuid1())
+        file, is_new = self.project.add_file(
+            uuid=file_uuid,
+            file_type="image",
+            name=f"name",
+            filename=f"image",
+            extension=".jpg",
+            size=32*1024,
+        )
+
+        self.assertTrue(is_new)
+
+        with open(file.absolute_path, "w"):
+            pass
+
+        self.project.external_data = True
+        self.assertTrue(os.path.exists(file.absolute_path))
+        url = url_for("remove_file", file_id=file.id)
+
+        self.assertEqual(1, self.project.files.count())
+        self.post(url, json=dict(remove=True), status_code=400)
+        self.assertEqual(1, self.project.files.count())

+ 8 - 12
tests/client/label_tests.py

@@ -6,31 +6,27 @@ from pycs.database.Label import Label
 from pycs.database.Model import Model
 from pycs.database.Project import Project
 
-from tests.base import BaseTestCase
+from tests.client.project_tests import _BaseProjectTests
 
 
-class _BaseLabelTests(BaseTestCase):
+class _BaseLabelTests(_BaseProjectTests):
 
     def setupModels(self):
-
-        model = Model.new(
-            commit=False,
-            name="TestModel",
-            description="Model for a test case",
-            root_folder="model_folder",
-        )
-        model.supports = ["labeled-image"]
-        model.flush()
+        super().setupModels()
 
         self.project = Project.new(
             name="test_project",
             description="Project for a test case",
-            model=model,
+            model=self.model,
             root_folder="project_folder",
             external_data=False,
             data_folder="project_folder/data",
         )
 
+    def tearDown(self):
+        self.project.delete()
+        super().tearDown()
+
 class LabelCreationTests(_BaseLabelTests):
 
     def setUp(self):

+ 3 - 0
tests/client/project_tests.py

@@ -28,6 +28,9 @@ class _BaseProjectTests(BaseTestCase):
 
         self.model = model
 
+    def tearDown(self):
+        self.model.delete()
+        super().tearDown()
 
 class ProjectCreationTests(_BaseProjectTests):
 

+ 18 - 0
tests/client/result_tests.py

@@ -0,0 +1,18 @@
+from flask import url_for
+
+from pycs.database.Model import Model
+from pycs.database.Project import Project
+
+from tests.client.label_tests import _BaseLabelTests
+
+
+class _BaseResultTests(_BaseLabelTests):
+
+    def setupModels(self):
+        super().setupModels()
+
+
+
+
+class ResultCreationTests(_BaseResultTests):
+    pass