浏览代码

Merge branch '127-all-database-functions-should-return-iterators' into 'master'

Resolve "all database functions should return iterators"

Closes #127

See merge request troebs/pycs!113
Eric Tröbs 3 年之前
父节点
当前提交
9205781d5f

+ 4 - 16
pycs/database/Collection.py

@@ -1,5 +1,5 @@
 from contextlib import closing
-from typing import List, Iterator
+from typing import Iterator
 
 from pycs.database.File import File
 
@@ -51,17 +51,7 @@ class Collection:
                            (self.project_id, self.identifier))
             return cursor.fetchone()[0]
 
-    def files(self, offset: int = 0, limit: int = -1) -> List[File]:
-        """
-        get a list of files associated with this collection
-
-        :param offset: file offset
-        :param limit: file limit
-        :return: list of files
-        """
-        return list(self.files_iter(offset, limit))
-
-    def files_iter(self, offset: int = 0, limit: int = -1) -> Iterator[File]:
+    def files(self, offset: int = 0, limit: int = -1) -> Iterator[File]:
         """
         get an iterator of files associated with this collection
 
@@ -76,7 +66,5 @@ class Collection:
                 ORDER BY id ASC LIMIT ? OFFSET ?
                 ''', (self.project_id, self.identifier, limit, offset))
 
-            return map(
-                lambda row: File(self.database, row),
-                cursor.fetchall()
-            )
+            for row in cursor:
+                yield File(self.database, row)

+ 13 - 19
pycs/database/Database.py

@@ -1,7 +1,7 @@
 import sqlite3
 from contextlib import closing
 from time import time
-from typing import Optional, List
+from typing import Optional, Iterator
 
 from pycs.database.Collection import Collection
 from pycs.database.File import File
@@ -155,18 +155,16 @@ class Database:
     def __exit__(self, exc_type, exc_val, exc_tb):
         self.con.__exit__(exc_type, exc_val, exc_tb)
 
-    def models(self) -> List[Model]:
+    def models(self) -> Iterator[Model]:
         """
         get a list of all available models
 
-        :return: list of all available models
+        :return: iterator of models
         """
         with closing(self.con.cursor()) as cursor:
             cursor.execute('SELECT * FROM models')
-            return list(map(
-                lambda row: Model(self, row),
-                cursor.fetchall()
-            ))
+            for row in cursor:
+                yield Model(self, row)
 
     def model(self, identifier: int) -> Optional[Model]:
         """
@@ -184,18 +182,16 @@ class Database:
 
             return None
 
-    def label_providers(self) -> List[LabelProvider]:
+    def label_providers(self) -> Iterator[LabelProvider]:
         """
         get a list of all available label providers
 
-        :return: list of all available label providers
+        :return: iterator over label providers
         """
         with closing(self.con.cursor()) as cursor:
             cursor.execute('SELECT * FROM label_providers')
-            return list(map(
-                lambda row: LabelProvider(self, row),
-                cursor.fetchall()
-            ))
+            for row in cursor:
+                yield LabelProvider(self, row)
 
     def label_provider(self, identifier: int) -> Optional[LabelProvider]:
         """
@@ -213,18 +209,16 @@ class Database:
 
             return None
 
-    def projects(self) -> List[Project]:
+    def projects(self) -> Iterator[Project]:
         """
         get a list of all available projects
 
-        :return: list of all available projects
+        :return: iterator over projects
         """
         with closing(self.con.cursor()) as cursor:
             cursor.execute('SELECT * FROM projects')
-            return list(map(
-                lambda row: Project(self, row),
-                cursor.fetchall()
-            ))
+            for row in cursor:
+                yield Project(self, row)
 
     def project(self, identifier: int) -> Optional[Project]:
         """

+ 9 - 23
pycs/database/Project.py

@@ -233,17 +233,7 @@ class Project:
             cursor.execute('SELECT COUNT(*) FROM files WHERE project = ?', [self.identifier])
             return cursor.fetchone()[0]
 
-    def files(self, offset: int = 0, limit: int = -1) -> List[File]:
-        """
-        get a list of files associated with this project
-
-        :param offset: file offset
-        :param limit: file limit
-        :return: list of files
-        """
-        return list(self.files_iter(offset, limit))
-
-    def files_iter(self, offset: int = 0, limit: int = -1) -> Iterator[File]:
+    def files(self, offset: int = 0, limit: int = -1) -> Iterator[File]:
         """
         get an iterator of files associated with this project
 
@@ -275,9 +265,9 @@ class Project:
             ''', [self.identifier])
             return cursor.fetchone()[0]
 
-    def files_without_results(self) -> List[File]:
+    def files_without_results(self) -> Iterator[File]:
         """
-        get a list of files without associated results
+        get an iterator of files without associated results
 
         :return: list of files
         """
@@ -290,10 +280,8 @@ class Project:
                 ORDER BY id ASC
             ''', [self.identifier])
 
-            return list(map(
-                lambda row: File(self.database, row),
-                cursor.fetchall()
-            ))
+            for row in cursor:
+                yield File(self.database, row)
 
     def count_files_without_collection(self) -> int:
         """
@@ -306,9 +294,9 @@ class Project:
                            [self.identifier])
             return cursor.fetchone()[0]
 
-    def files_without_collection(self, offset=0, limit=-1) -> List[File]:
+    def files_without_collection(self, offset=0, limit=-1) -> Iterator[File]:
         """
-        get a list of files without not associated with any collection
+        get an iterator of files without not associated with any collection
 
         :return: list of files
         """
@@ -320,10 +308,8 @@ class Project:
                 LIMIT ? OFFSET ?
             ''', (self.identifier, limit, offset))
 
-            return list(map(
-                lambda row: File(self.database, row),
-                cursor.fetchall()
-            ))
+            for row in cursor:
+                yield File(self.database, row)
 
     def file(self, identifier) -> Optional[File]:
         """

+ 2 - 1
pycs/frontend/endpoints/ListLabelProviders.py

@@ -16,4 +16,5 @@ class ListLabelProviders(View):
         self.db = db
 
     def dispatch_request(self):
-        return jsonify(self.db.label_providers())
+        label_providers = list(self.db.label_providers())
+        return jsonify(label_providers)

+ 2 - 1
pycs/frontend/endpoints/ListModels.py

@@ -16,4 +16,5 @@ class ListModels(View):
         self.db = db
 
     def dispatch_request(self):
-        return jsonify(self.db.models())
+        models = list(self.db.models())
+        return jsonify(models)

+ 2 - 1
pycs/frontend/endpoints/ListProjects.py

@@ -16,4 +16,5 @@ class ListProjects(View):
         self.db = db
 
     def dispatch_request(self):
-        return jsonify(self.db.projects())
+        projects = list(self.db.projects())
+        return jsonify(projects)

+ 3 - 3
pycs/frontend/endpoints/projects/ListFiles.py

@@ -25,17 +25,17 @@ class ListFiles(View):
         if collection_id is not None:
             if collection_id == 0:
                 count = project.count_files_without_collection()
-                files = project.files_without_collection(start, length)
+                files = list(project.files_without_collection(start, length))
             else:
                 collection = project.collection(collection_id)
                 if collection is None:
                     return abort(404)
 
                 count = collection.count_files()
-                files = collection.files(start, length)
+                files = list(collection.files(start, length))
         else:
             count = project.count_files()
-            files = project.files(start, length)
+            files = list(project.files(start, length))
 
         # return files
         return jsonify({

+ 13 - 13
test/test_database.py

@@ -30,8 +30,8 @@ class TestDatabase(unittest.TestCase):
                 ''')
 
                 # projects
-                models = self.database.models()
-                label_providers = self.database.label_providers()
+                models = list(self.database.models())
+                label_providers = list(self.database.label_providers())
 
                 for i in range(3):
                     self.database.create_project(
@@ -45,7 +45,7 @@ class TestDatabase(unittest.TestCase):
         self.database.close()
 
     def test_models(self):
-        models = self.database.models()
+        models = list(self.database.models())
 
         # test length
         self.assertEqual(len(models), 3)
@@ -69,7 +69,7 @@ class TestDatabase(unittest.TestCase):
         self.assertEqual(copy.supports, ['labeled-image', 'fit'])
 
     def test_label_providers(self):
-        label_providers = self.database.label_providers()
+        label_providers = list(self.database.label_providers())
 
         # test length
         self.assertEqual(len(label_providers), 2)
@@ -81,9 +81,9 @@ class TestDatabase(unittest.TestCase):
             self.assertEqual(label_providers[i].root_folder, f'labeldir{i + 1}')
 
     def test_projects(self):
-        models = self.database.models()
-        label_providers = self.database.label_providers()
-        projects = self.database.projects()
+        models = list(self.database.models())
+        label_providers = list(self.database.label_providers())
+        projects = list(self.database.projects())
 
         # create projects
         for i in range(3):
@@ -104,23 +104,23 @@ class TestDatabase(unittest.TestCase):
             self.assertEqual(project.data_folder, f'datadir{i + 1}')
 
         # get projects
-        self.assertEqual(len(self.database.projects()), 3)
+        self.assertEqual(len(list(self.database.projects())), 3)
 
         # remove a project
-        self.database.projects()[0].remove()
-        projects = self.database.projects()
+        list(self.database.projects())[0].remove()
+        projects = list(self.database.projects())
 
         self.assertEqual(len(projects), 2)
         self.assertEqual(projects[0].name, 'Project 2')
 
         # set properties
-        project = self.database.projects()[0]
+        project = list(self.database.projects())[0]
 
         project.set_name('Project 0')
-        self.assertEqual(self.database.projects()[0].name, 'Project 0')
+        self.assertEqual(list(self.database.projects())[0].name, 'Project 0')
 
         project.set_description('Description 0')
-        self.assertEqual(self.database.projects()[0].description, 'Description 0')
+        self.assertEqual(list(self.database.projects())[0].description, 'Description 0')
 
 
 if __name__ == '__main__':