123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519 |
- import uuid
- from flask import url_for
- from pycs.database.Collection import Collection
- from pycs.database.File import File
- from pycs.database.Label import Label
- from pycs.database.Model import Model
- from pycs.database.Project import Project
- from pycs.database.Result import Result
- from pycs.interfaces.MediaFile import MediaFile
- from tests.base import BaseTestCase
class _BaseProjectTests(BaseTestCase):
    """Shared fixture for the project tests: owns ``self.model``.

    ``setupModels`` is a hook defined on ``BaseTestCase`` (presumably invoked
    from its ``setUp`` — TODO confirm); subclasses extend it via ``super()``.
    """

    def setupModels(self):
        """Create the ``TestModel`` record every project test builds on."""
        template = Model.new(
            commit=False,
            name="TestModel",
            description="Model for a test case",
            root_folder="models/fixed_model",
        )
        # Advertise the capability the project endpoints are exercised with,
        # then flush the pending (uncommitted) record.
        template.supports = ["labeled-image"]
        template.flush()
        self.model = template

    def tearDown(self):
        """Delete the shared model before handing over to the base teardown."""
        self.model.delete()
        super().tearDown()
class ProjectCreationTests(_BaseProjectTests):
    """Tests for the ``create_project`` endpoint."""

    def setUp(self):
        super().setUp()
        self.url = url_for("create_project")

    def _payload(self, *omit):
        """Return a complete creation payload, minus the keys named in *omit*.

        Lets the negative tests state explicitly which required field is
        missing instead of relying on commented-out dictionary entries.
        """
        payload = dict(
            name="Some Project",
            description="Some description",
            model=self.model.id,
            label=None,
            external=None,
        )
        for key in omit:
            del payload[key]
        return payload

    def test_project_creation_without_name(self):
        """A request lacking ``name`` is rejected with 400 and creates nothing."""
        self.assertEqual(0, Project.query.count())
        self.post(self.url, json=self._payload("name"), status_code=400)
        self.assertEqual(0, Project.query.count())

    def test_project_creation_without_description(self):
        """A request lacking ``description`` is rejected with 400 and creates nothing."""
        self.assertEqual(0, Project.query.count())
        self.post(self.url, json=self._payload("description"), status_code=400)
        self.assertEqual(0, Project.query.count())

    def test_project_creation(self):
        """A complete request creates one project with its own copy of the model."""
        self.assertEqual(0, Project.query.count())
        self.post(self.url, json=self._payload())
        self.assertEqual(1, Project.query.count())

        project = Project.query.first()
        self.assertIsNotNone(project)
        self.assertIsNotNone(project.model)
        self.assertNotEqual(project.model.id, self.model.id,
                            "The model should be copied")
        self.assertIsNone(project.label_provider)
class ProjectDeletionTests(_BaseProjectTests):
    """Tests for the ``remove_project`` endpoint."""

    def setupModels(self):
        """Extend the base fixture with the project the deletion tests remove."""
        super().setupModels()
        self.project = Project.new(
            name="test_project",
            description="Project for a test case",
            model=self.model,
            root_folder="project_folder",
            external_data=False,
            data_folder="project_folder/data",
        )

    @property
    def url(self):
        """Removal URL of the fixture project."""
        return url_for("remove_project", project_id=self.project.id)

    def _assert_project_count(self, expected):
        """Check how many projects currently exist in the database."""
        self.assertEqual(expected, Project.query.count())

    def test_project_deletion(self):
        self._assert_project_count(1)
        self.post(self.url, json=dict(remove=True))
        self._assert_project_count(0)

    def test_project_deletion_without_flag(self):
        self._assert_project_count(1)
        # Neither a missing flag nor an explicit False may delete the project.
        for payload in (dict(), dict(remove=False)):
            self.post(self.url, json=payload, status_code=400)
            self._assert_project_count(1)

    def test_project_deletion_with_labels(self):
        self._assert_project_count(1)
        self.assertEqual(0, Label.query.count())
        for number in range(1, 11):
            self.project.create_label(name=f"Label_{number}")
        self.assertEqual(10, Label.query.count())

        self.post(self.url, json=dict(remove=True))

        # Removing the project must take its labels with it.
        self._assert_project_count(0)
        self.assertEqual(0, Label.query.count())
class ProjectListTests(_BaseProjectTests):
    """Tests for the read-only project endpoints.

    Covers the listings of projects, collections and files (full and
    paginated), project results, and the project's model.
    """

    def _create_project(self, **overrides):
        """Create a project bound to ``self.model``.

        The defaults are the single-project fixture shared by most tests. Any
        keyword argument overrides the matching ``Project.new`` field.
        """
        params = dict(
            name="TestProject",
            description="Project for a test case",
            model=self.model,
            root_folder="project_folder",
            external_data=False,
            data_folder="project_folder/data",
        )
        params.update(overrides)
        return Project.new(**params)

    def _add_image_files(self, owner, count, prefix=""):
        """Add ``count`` new 32 KiB ``.jpg`` image files to ``owner``.

        ``owner`` is anything exposing ``add_file`` (a project or a collection
        in these tests). Files are named ``{prefix}name{i}`` and
        ``{prefix}image_{i:03d}`` with ``i`` starting at 1. Returns them in
        creation order.
        """
        files = []
        for i in range(1, count + 1):
            new_file, is_new = owner.add_file(
                uuid=str(uuid.uuid1()),
                file_type="image",
                name=f"{prefix}name{i}",
                filename=f"{prefix}image_{i:03d}",
                extension=".jpg",
                size=32 * 1024,
            )
            self.assertTrue(is_new)
            files.append(new_file)
        return files

    def _get_file_page(self, url):
        """GET a file listing and return its ``(count, files)`` pair."""
        response = self.get(url)
        self.assertTrue(response.is_json)
        payload = response.json
        return payload["count"], payload["files"]

    def _assert_file_page(self, url, expected, start, length):
        """Check one file-listing response against ``expected[start:start+length]``.

        ``count`` must report the total number of files, independent of the
        requested window.
        """
        count, content = self._get_file_page(url)
        self.assertEqual(len(expected), count)
        window = expected[start:start + length]
        self.assertEqual(len(window), len(content))
        for db_file, entry in zip(window, content):
            self.assertDictEqual(entry, db_file.serialize())

    def test_list_projects(self):
        self.assertEqual(0, Project.query.count())
        for i in range(1, 11):
            self._create_project(
                name=f"TestProject{i}",
                description=f"Project for a test case #{i}",
                root_folder=f"project_folder{i}",
                data_folder=f"project_folder{i}/data",
            )
        self.assertEqual(10, Project.query.count())

        response = self.get(url_for("list_projects"))
        self.assertTrue(response.is_json)
        content = response.json
        self.assertEqual(10, len(content))
        for entry in content:
            project = Project.query.get(entry["id"])
            self.assertIsNotNone(project)
            self.assertDictEqual(entry, project.serialize())

    def test_list_project_collections(self):
        project = self._create_project()
        self.assertEqual(0, Collection.query.count())
        for i in range(1, 11):
            project.create_collection(
                reference=f"collection_{i}",
                name=f"Some collection {i}",
                description=f"A description {i}",
                position=i,
                autoselect=i == 1,
            )
        self.assertEqual(10, Collection.query.count())

        response = self.get(url_for("list_collections", project_id=project.id))
        self.assertTrue(response.is_json)
        content = response.json
        self.assertEqual(10, len(content))
        for entry in content:
            collection = Collection.query.get(entry["id"])
            self.assertIsNotNone(collection)
            self.assertDictEqual(entry, collection.serialize())

    def test_list_all_files(self):
        project = self._create_project()
        self.assertEqual(0, File.query.count())
        files = self._add_image_files(project, 10)
        self.assertEqual(10, File.query.count())

        url = url_for("list_all_files", project_id=project.id)
        self._assert_file_page(url, files, 0, len(files))

    def test_list_some_files(self):
        project = self._create_project()
        self.assertEqual(0, File.query.count())
        files = self._add_image_files(project, 10)
        self.assertEqual(10, File.query.count())

        # Windows inside the range and windows running past its end.
        for start, length in [(0, 5), (0, 15), (5, 3), (5, 8)]:
            url = url_for("list_files", project_id=project.id,
                          start=start, length=length)
            self._assert_file_page(url, files, start, length)

    def test_list_collection_files_of_non_existing_collection(self):
        project = self._create_project()
        url = url_for("list_collection_files",
                      project_id=project.id, collection_id=42,
                      start=0, length=30)
        self.get(url, status_code=404)

    def test_list_collection_files(self):
        project = self._create_project()
        self.assertEqual(1, Project.query.count())

        # collection id -> files expected in that collection. Key 0 holds the
        # files added directly to the project (collection_id is NULL).
        expected_by_collection = {}
        for i in range(1, 3):
            collection, is_new = project.create_collection(
                reference=f"collection_{i}",
                name=f"Some collection {i}",
                description=f"A description {i}",
                position=i,
                autoselect=i == 1,
            )
            self.assertTrue(is_new)
            expected_by_collection[collection.id] = self._add_image_files(
                collection, 3, prefix=f"col_{i}_")
        expected_by_collection[0] = self._add_image_files(project, 3)

        self.assertEqual(2, Collection.query.filter(
            Collection.project_id == project.id).count())
        # SQLAlchemy overloads ==/!= None into IS [NOT] NULL, so `is` cannot be used here.
        self.assertEqual(6, File.query.filter(
            File.project_id == project.id,
            File.collection_id != None,  # noqa: E711
        ).count())
        self.assertEqual(3, File.query.filter(
            File.project_id == project.id,
            File.collection_id == None,  # noqa: E711
        ).count())

        for collection_id, expected in expected_by_collection.items():
            for start, length in [(0, 5), (0, 15), (1, 3), (1, 8)]:
                url = url_for("list_collection_files",
                              project_id=project.id,
                              collection_id=collection_id,
                              start=start, length=length)
                self._assert_file_page(url, expected, start, length)

    def test_list_project_results(self):
        project = self._create_project(name="test_project")

        # Only one project exists, so any other id must be unknown.
        self.get(url_for("get_project_results", project_id=project.id + 1),
                 status_code=404)

        for db_file in self._add_image_files(project, 3):
            db_file.create_result(
                origin="user",
                result_type="bounding-box",
                label=None,
                data=dict(x=0, y=0, w=1, h=1),
            )
        self.assertEqual(3, File.query.count())
        self.assertEqual(3, Result.query.count())

        response = self.get(url_for("get_project_results",
                                    project_id=project.id))
        self.assertTrue(response.is_json)
        content = response.json
        # assertEqual, not assertTrue: assertTrue(a, b) treats b as the
        # failure message and passes for any non-zero count.
        self.assertEqual(project.files.count(), len(content))

        media_files = {f.filename: MediaFile(f, None)
                       for f in project.files.all()}
        for entry in content:
            media_file = media_files[entry["filename"]]
            self.assertDictEqual(media_file.serialize(), entry)

    def test_get_project_model(self):
        project = self._create_project(name="test_project")

        # Only one project exists, so any other id must be unknown.
        self.get(url_for("get_project_model", project_id=project.id + 1),
                 status_code=404)

        response = self.get(url_for("get_project_model",
                                    project_id=project.id))
        self.assertTrue(response.is_json)
        content = response.json
        self.assertDictEqual(project.model.serialize(), content)
class ProjectEditTests(_BaseProjectTests):
    """Tests for the project name / description edit endpoints."""

    def setupModels(self):
        """Extend the base fixture with the project whose fields are edited."""
        super().setupModels()
        self.project = Project.new(
            name="test_project",
            description="Project for a test case",
            model=self.model,
            root_folder="project_folder",
            external_data=False,
            data_folder="project_folder/data",
        )

    def _edit_url(self, endpoint):
        """Build the URL of an edit ``endpoint`` for the fixture project."""
        return url_for(endpoint, project_id=self.project.id)

    def test_name_edit(self):
        new_name = "new_name"
        target = self._edit_url("edit_project_name")
        self.post(target, json=dict(name=new_name))
        self.assertEqual(new_name, self.project.name)

    def test_name_edit_missing_argument(self):
        target = self._edit_url("edit_project_name")
        self.post(target, json=dict(), status_code=400)
        # The rejected request must leave the original name in place.
        self.assertEqual("test_project", self.project.name)

    def test_description_edit(self):
        new_description = "New description for the project"
        target = self._edit_url("edit_project_description")
        self.post(target, json=dict(description=new_description))
        self.assertEqual(new_description, self.project.description)

    def test_description_edit_missing_argument(self):
        target = self._edit_url("edit_project_description")
        self.post(target, json=dict(), status_code=400)
        # The rejected request must leave the original description in place.
        self.assertEqual("Project for a test case", self.project.description)
|