|
@@ -1,11 +1,12 @@
|
|
|
import unittest
|
|
|
|
|
|
-from pycs.database.Database import Database
|
|
|
+from pycs import db
|
|
|
from pycs.database.File import File
|
|
|
from pycs.database.Label import Label
|
|
|
-from pycs.database.Result import Result
|
|
|
-from pycs.database.Model import Model
|
|
|
from pycs.database.LabelProvider import LabelProvider
|
|
|
+from pycs.database.Model import Model
|
|
|
+from pycs.database.Project import Project
|
|
|
+from pycs.database.Result import Result
|
|
|
|
|
|
from test.base import BaseTestCase
|
|
|
|
|
@@ -15,7 +16,7 @@ class DatabaseTests(BaseTestCase):
|
|
|
super().setUp(discovery=False)
|
|
|
|
|
|
# insert default models and label_providers
|
|
|
- with self.database:
|
|
|
+ with db.session.begin_nested():
|
|
|
|
|
|
for i, supports in enumerate([["labeled-image", "fit"], ["labeled-bounding-boxes"], ["labeled-bounding-boxes"]], 1):
|
|
|
|
|
@@ -39,11 +40,11 @@ class DatabaseTests(BaseTestCase):
|
|
|
)
|
|
|
|
|
|
# projects
|
|
|
- models = list(self.database.models())
|
|
|
- label_providers = list(self.database.label_providers())
|
|
|
+ models = Model.query.all()
|
|
|
+ label_providers = LabelProvider.query.all()
|
|
|
|
|
|
for i, model in enumerate(models, 1):
|
|
|
- self.database.create_project(
|
|
|
+ Project.new(
|
|
|
name=f'Project {i}',
|
|
|
description=f'Project Description {i}',
|
|
|
model=model,
|
|
@@ -54,7 +55,7 @@ class DatabaseTests(BaseTestCase):
|
|
|
)
|
|
|
|
|
|
def test_models(self):
|
|
|
- models = list(self.database.models())
|
|
|
+ models = Model.query.all()
|
|
|
|
|
|
# test length
|
|
|
self.assertEqual(len(models), 3)
|
|
@@ -78,7 +79,7 @@ class DatabaseTests(BaseTestCase):
|
|
|
self.assertEqual(copy.supports, ['labeled-image', 'fit'])
|
|
|
|
|
|
def test_label_providers(self):
|
|
|
- label_providers = list(self.database.label_providers())
|
|
|
+ label_providers = LabelProvider.query.all()
|
|
|
|
|
|
# test length
|
|
|
self.assertEqual(len(label_providers), 2)
|
|
@@ -90,9 +91,9 @@ class DatabaseTests(BaseTestCase):
|
|
|
self.assertEqual(label_providers[i].root_folder, f'labeldir{i + 1}')
|
|
|
|
|
|
def test_projects(self):
|
|
|
- models = list(self.database.models())
|
|
|
- label_providers = list(self.database.label_providers())
|
|
|
- projects = list(self.database.projects())
|
|
|
+ models = Model.query.all()
|
|
|
+ label_providers = LabelProvider.query.all()
|
|
|
+ projects = Project.query.all()
|
|
|
|
|
|
# create projects
|
|
|
for i in range(3):
|
|
@@ -113,28 +114,27 @@ class DatabaseTests(BaseTestCase):
|
|
|
self.assertEqual(project.data_folder, f'datadir{i + 1}')
|
|
|
|
|
|
# get projects
|
|
|
- self.assertEqual(len(list(self.database.projects())), 3)
|
|
|
+ self.assertEqual(Project.query.count(), 3)
|
|
|
|
|
|
# remove a project
|
|
|
- list(self.database.projects())[0].remove()
|
|
|
- projects = list(self.database.projects())
|
|
|
+ Project.query.first().remove()
|
|
|
|
|
|
- self.assertEqual(len(projects), 2)
|
|
|
- self.assertEqual(projects[0].name, 'Project 2')
|
|
|
+ self.assertEqual(Project.query.count(), 2)
|
|
|
+ self.assertEqual(Project.query.first().name, 'Project 2')
|
|
|
|
|
|
# set properties
|
|
|
- project = list(self.database.projects())[0]
|
|
|
+ project = Project.query.first()
|
|
|
|
|
|
project.set_name('Project 0')
|
|
|
- self.assertEqual(list(self.database.projects())[0].name, 'Project 0')
|
|
|
+ self.assertEqual(Project.query.first().name, 'Project 0')
|
|
|
|
|
|
project.set_description('Description 0')
|
|
|
- self.assertEqual(list(self.database.projects())[0].description, 'Description 0')
|
|
|
+ self.assertEqual(Project.query.first().description, 'Description 0')
|
|
|
|
|
|
|
|
|
def test_no_files_after_project_deletion(self):
|
|
|
|
|
|
- project = self.database.project(1)
|
|
|
+ project = Project.query.get(1)
|
|
|
for i in range(5):
|
|
|
file, is_new = project.add_file(
|
|
|
uuid=f"some_string{i}",
|
|
@@ -151,13 +151,13 @@ class DatabaseTests(BaseTestCase):
|
|
|
self.assertEqual(5, File.query.filter_by(project_id=project.id).count())
|
|
|
|
|
|
project.remove()
|
|
|
- self.assertIsNone(self.database.project(1))
|
|
|
+ self.assertIsNone(Project.query.get(1))
|
|
|
self.assertEqual(0, File.query.filter_by(project_id=project.id).count())
|
|
|
|
|
|
def test_no_labels_after_project_deletion(self):
|
|
|
|
|
|
self.assertEqual(0, Label.query.count())
|
|
|
- project = self.database.project(1)
|
|
|
+ project = Project.query.get(1)
|
|
|
for i in range(5):
|
|
|
label, is_new = project.create_label(
|
|
|
name=f"label{i}",
|
|
@@ -171,13 +171,13 @@ class DatabaseTests(BaseTestCase):
|
|
|
|
|
|
project.remove()
|
|
|
|
|
|
- self.assertIsNone(self.database.project(1))
|
|
|
+ self.assertIsNone(Project.query.get(1))
|
|
|
self.assertEqual(0, Label.query.count())
|
|
|
|
|
|
|
|
|
def test_no_results_after_file_deletion(self):
|
|
|
|
|
|
- project = self.database.project(1)
|
|
|
+ project = Project.query.get(1)
|
|
|
self.assertIsNotNone(project)
|
|
|
|
|
|
file, is_new = project.add_file(
|