Browse Source

added tests for listing projects, files, and collections

Dimitri Korsch 3 years ago
parent
commit
4cbcf8c0bc

+ 45 - 2
pycs/database/Collection.py

@@ -1,9 +1,12 @@
 from __future__ import annotations
 
+import os
+
 from typing import List
 
 from pycs import db
 from pycs.database.base import NamedBaseModel
+from pycs.database.util import commit_on_return
 
 class Collection(NamedBaseModel):
     """ DB Model for collections """
@@ -44,7 +47,7 @@ class Collection(NamedBaseModel):
     )
 
 
-    def get_files(self, offset: int = 0, limit: int = -1):
+    def get_files(self, *filters, offset: int = 0, limit: int = -1):
         """
         get an iterator of files associated with this project
 
@@ -56,7 +59,47 @@ class Collection(NamedBaseModel):
         # pylint: disable=import-outside-toplevel
         # pylint: disable=cyclic-import
         from pycs.database.File import File
-        return self.files.order_by(File.id).offset(offset).limit(limit)
+        return self.files.filter(*filters).order_by(File.id).offset(offset).limit(limit)
+
+    # pylint: disable=too-many-arguments
+    @commit_on_return
+    def add_file(self,
+                 uuid: str,
+                 file_type: str,
+                 name: str,
+                 extension: str,
+                 size: int,
+                 filename: str,
+                 frames: int = None,
+                 fps: float = None) -> T.Tuple["File", bool]:
+        """
+        add a file to this collection
+
+        :param uuid: unique identifier which is used for temporary files
+        :param file_type: file type (either image or video)
+        :param name: file name
+        :param extension: file extension
+        :param size: file size
+        :param filename: actual name in filesystem
+        :param frames: frame count
+        :param fps: frames per second
+        :return: file
+        """
+        path = os.path.join(self.project.data_folder, f"{filename}{extension}")
+        from pycs.database.File import File
+
+        file, is_new = File.get_or_create(
+            project_id=self.project_id, collection_id=self.id, path=path)
+
+        file.uuid = uuid
+        file.type = file_type
+        file.name = name
+        file.extension = extension
+        file.size = size
+        file.frames = frames
+        file.fps = fps
+
+        return file, is_new
 
 
     @staticmethod

+ 4 - 3
pycs/database/Project.py

@@ -293,7 +293,7 @@ class Project(NamedBaseModel):
         return file, is_new
 
 
-    def get_files(self, offset: int = 0, limit: int = -1) -> T.List[File]:
+    def get_files(self, *filters, offset: int = 0, limit: int = -1) -> T.List[File]:
         """
         get an iterator of files associated with this project
 
@@ -301,7 +301,8 @@ class Project(NamedBaseModel):
         :param limit: file limit
         :return: iterator of files
         """
-        return self.files.order_by(File.id).offset(offset).limit(limit)
+
+        return self.files.filter(*filters).order_by(File.id).offset(offset).limit(limit)
 
 
     def _files_without_results(self):
@@ -340,7 +341,7 @@ class Project(NamedBaseModel):
         :return: a query object
         """
         # pylint: disable=no-member
-        return self.get_files(offset, limit).filter(File.collection_id.is_(None))
+        return self.get_files(File.collection_id.is_(None), offset=offset, limit=limit)
 
     def files_without_collection(self, offset: int = 0, limit: int = -1) -> T.List[File]:
         """

+ 4 - 0
pycs/frontend/WebServer.py

@@ -206,6 +206,10 @@ class WebServer:
             '/projects/<int:project_id>/data',
             view_func=UploadFile.as_view('upload_file', notifications)
         )
+        self.app.add_url_rule(
+            '/projects/<int:project_id>/data',
+            view_func=ListProjectFiles.as_view('list_all_files')
+        )
         self.app.add_url_rule(
             '/projects/<int:project_id>/data/<int:start>/<int:length>',
             view_func=ListProjectFiles.as_view('list_files')

+ 2 - 1
pycs/frontend/endpoints/projects/CreateProject.py

@@ -88,7 +88,8 @@ class CreateProject(View):
                 commit=False)
 
             if not is_new:
-                abort(400, f"Could not copy model! Model in \"{model_folder}\" already exists!")
+                abort(400, # pragma: no cover
+                    f"Could not copy model! Model in \"{model_folder}\" already exists!")
 
             project = Project.new(name=name,
                                   description=description,

+ 8 - 4
pycs/frontend/endpoints/projects/ListProjectFiles.py

@@ -13,7 +13,11 @@ class ListProjectFiles(View):
     methods = ['GET']
 
 
-    def dispatch_request(self, project_id: int, start: int, length: int, collection_id: int = None):
+    def dispatch_request(self,
+                         project_id: int,
+                         start: int = 0,
+                         length: int = -1,
+                         collection_id: int = None):
         # find project
 
         project = Project.get_or_404(project_id)
@@ -22,7 +26,7 @@ class ListProjectFiles(View):
         if collection_id is not None:
             if collection_id == 0:
                 count = project.count_files_without_collection()
-                files = project.files_without_collection(start, length)
+                files = project.files_without_collection(offset=start, limit=length)
 
             else:
                 collection = project.collection(collection_id)
@@ -30,11 +34,11 @@ class ListProjectFiles(View):
                     abort(404)
 
                 count = collection.files.count()
-                files = collection.get_files(start, length).all()
+                files = collection.get_files(offset=start, limit=length).all()
 
         else:
             count = project.files.count()
-            files = project.get_files(start, length).all()
+            files = project.get_files(offset=start, limit=length).all()
 
         # return files
         return jsonify({

+ 197 - 4
tests/client/project_tests.py

@@ -1,3 +1,5 @@
+import uuid
+
 from flask import url_for
 
 from pycs.database.Collection import Collection
@@ -155,7 +157,7 @@ class ProjectListTests(_BaseProjectTests):
         self.assertTrue(response.is_json)
         content = response.json
 
-        self.assertTrue(10, len(content))
+        self.assertEqual(10, len(content))
 
         for entry in content:
             project = Project.query.get(entry["id"])
@@ -165,7 +167,7 @@ class ProjectListTests(_BaseProjectTests):
     def test_list_project_collections(self):
         project = Project.new(
             name="TestProject",
-            description="Project for a test case #",
+            description="Project for a test case",
             model=self.model,
             root_folder="project_folder",
             external_data=False,
@@ -180,7 +182,7 @@ class ProjectListTests(_BaseProjectTests):
                 name=f"Some collection {i}",
                 description=f"A description {i}",
                 position=i,
-                autoselect=i is 1
+                autoselect=i == 1
             )
         self.assertEqual(10, Collection.query.count())
 
@@ -190,9 +192,200 @@ class ProjectListTests(_BaseProjectTests):
         self.assertTrue(response.is_json)
         content = response.json
 
-        self.assertTrue(10, len(content))
+        self.assertEqual(10, len(content))
 
         for entry in content:
             collection = Collection.query.get(entry["id"])
             self.assertIsNotNone(collection)
             self.assertDictEqual(entry, collection.serialize())
+
+    def test_list_all_files(self):
+        project = Project.new(
+            name="TestProject",
+            description="Project for a test case",
+            model=self.model,
+            root_folder="project_folder",
+            external_data=False,
+            data_folder="project_folder/data",
+        )
+
+
+        self.assertEqual(0, File.query.count())
+        files = []
+        for i in range(1, 11):
+            file_uuid = str(uuid.uuid1())
+            file, is_new = project.add_file(
+                uuid=file_uuid,
+                file_type="image",
+                name=f"name{i}",
+                filename=f"image_{i:03d}",
+                extension=".jpg",
+                size=32*1024,
+            )
+            self.assertTrue(is_new)
+            files.append(file)
+
+        self.assertEqual(10, File.query.count())
+
+        response = self.get(url_for("list_all_files",
+            project_id=project.id))
+
+        self.assertTrue(response.is_json)
+        _content = response.json
+        count = _content["count"]
+        content = _content["files"]
+
+        self.assertEqual(10, count)
+        self.assertEqual(10, len(content))
+
+        for file, entry in zip(files, content):
+            self.assertDictEqual(entry, file.serialize())
+
+    def test_list_some_files(self):
+        project = Project.new(
+            name="TestProject",
+            description="Project for a test case",
+            model=self.model,
+            root_folder="project_folder",
+            external_data=False,
+            data_folder="project_folder/data",
+        )
+
+
+        self.assertEqual(0, File.query.count())
+        files = []
+        for i in range(1, 11):
+            file_uuid = str(uuid.uuid1())
+            file, is_new = project.add_file(
+                uuid=file_uuid,
+                file_type="image",
+                name=f"name{i}",
+                filename=f"image_{i:03d}",
+                extension=".jpg",
+                size=32*1024,
+            )
+            self.assertTrue(is_new)
+            files.append(file)
+
+        self.assertEqual(10, File.query.count())
+
+        for start, length in [(0, 5), (0, 15), (5, 3), (5, 8)]:
+            response = self.get(url_for("list_files",
+                project_id=project.id,
+                start=start, length=length))
+
+            self.assertTrue(response.is_json)
+            _content = response.json
+            count = _content["count"]
+            content = _content["files"]
+
+            self.assertEqual(len(files), count)
+            self.assertEqual(min(len(files), start+length) - start, len(content))
+
+            for file, entry in zip(files[start:start+length], content):
+                self.assertDictEqual(entry, file.serialize())
+
+
+    def test_list_collection_files_of_non_existing_collection(self):
+
+        project = Project.new(
+            name="TestProject",
+            description="Project for a test case",
+            model=self.model,
+            root_folder="project_folder",
+            external_data=False,
+            data_folder="project_folder/data",
+        )
+
+        url = url_for("list_collection_files",
+                      project_id=project.id, collection_id=42,
+                      start=0, length=30)
+        self.get(url, status_code=404)
+
+
+    def test_list_collection_files(self):
+        project = Project.new(
+            name="TestProject",
+            description="Project for a test case",
+            model=self.model,
+            root_folder="project_folder",
+            external_data=False,
+            data_folder="project_folder/data",
+        )
+
+        self.assertEqual(1, Project.query.count())
+
+        collections = {}
+        for i in range(1, 3):
+            collection, is_new = project.create_collection(
+                reference=f"collection_{i}",
+                name=f"Some collection {i}",
+                description=f"A description {i}",
+                position=i,
+                autoselect=i == 1
+            )
+
+            self.assertTrue(is_new)
+
+            collection_files = []
+
+            for j in range(1, 4):
+                file_uuid = str(uuid.uuid1())
+                file, is_new = collection.add_file(
+                    uuid=file_uuid,
+                    file_type="image",
+                    name=f"col_{i}_name{j}",
+                    filename=f"col_{i}_image_{j:03d}",
+                    extension=".jpg",
+                    size=32*1024,
+                )
+                self.assertTrue(is_new)
+                collection_files.append(file)
+
+            collections[collection.id] = collection_files
+
+        files = []
+        for j in range(1, 4):
+            file_uuid = str(uuid.uuid1())
+            file, is_new = project.add_file(
+                uuid=file_uuid,
+                file_type="image",
+                name=f"name{j}",
+                filename=f"image_{j:03d}",
+                extension=".jpg",
+                size=32*1024,
+            )
+            self.assertTrue(is_new)
+            files.append(file)
+
+        collections[0] = files
+
+        self.assertEqual(2, Collection.query.filter(Collection.project_id==project.id).count())
+
+        self.assertEqual(6, File.query.filter(
+            File.project_id == project.id,
+            File.collection_id != None,
+        ).count())
+
+        self.assertEqual(3, File.query.filter(
+            File.project_id == project.id,
+            File.collection_id == None,
+        ).count())
+
+
+        for collection_id, files in collections.items():
+            for start, length in [(0, 5), (0, 15), (1, 3), (1, 8)]:
+                response = self.get(url_for("list_collection_files",
+                    project_id=project.id, collection_id=collection_id,
+                    start=start, length=length))
+
+                self.assertTrue(response.is_json)
+                _content = response.json
+                count = _content["count"]
+                content = _content["files"]
+
+                self.assertEqual(len(files), count)
+                self.assertEqual(min(len(files), start+length) - start, len(content))
+
+                for file, entry in zip(files[start:start+length], content):
+                    self.assertDictEqual(entry, file.serialize())