6
0

project_tests.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517
  1. import uuid
  2. from flask import url_for
  3. from pycs.database.Collection import Collection
  4. from pycs.database.File import File
  5. from pycs.database.Label import Label
  6. from pycs.database.Model import Model
  7. from pycs.database.Project import Project
  8. from pycs.database.Result import Result
  9. from pycs.interfaces.MediaFile import MediaFile
  10. from tests.base import BaseTestCase
  11. class _BaseProjectTests(BaseTestCase):
  12. def setupModels(self):
  13. model = Model.new(
  14. commit=False,
  15. name="TestModel",
  16. description="Model for a test case",
  17. root_folder="models/fixed_model",
  18. )
  19. model.supports = ["labeled-image"]
  20. model.flush()
  21. self.model = model
  22. def tearDown(self):
  23. self.model.delete()
  24. super().tearDown()
  25. class ProjectCreationTests(_BaseProjectTests):
  26. def setUp(self):
  27. super().setUp()
  28. self.url = url_for("create_project")
  29. def test_project_creation_without_name(self):
  30. self.assertEqual(0, Project.query.count())
  31. self.post(self.url, json=dict(
  32. # name="Some Project",
  33. description="Some description",
  34. model=self.model.id,
  35. label=None,
  36. external=None,
  37. ),
  38. status_code=400,
  39. )
  40. self.assertEqual(0, Project.query.count())
  41. def test_project_creation_without_description(self):
  42. self.assertEqual(0, Project.query.count())
  43. self.post(self.url, json=dict(
  44. name="Some Project",
  45. # description="Some description",
  46. model=self.model.id,
  47. label=None,
  48. external=None,
  49. ),
  50. status_code=400,
  51. )
  52. self.assertEqual(0, Project.query.count())
  53. def test_project_creation(self):
  54. self.assertEqual(0, Project.query.count())
  55. self.post(self.url, json=dict(
  56. name="Some Project",
  57. description="Some description",
  58. model=self.model.id,
  59. label=None,
  60. external=None,
  61. )
  62. )
  63. self.assertEqual(1, Project.query.count())
  64. project = Project.query.first()
  65. self.assertIsNotNone(project)
  66. self.assertIsNotNone(project.model)
  67. self.assertIsNone(project.label_provider)
  68. class ProjectDeletionTests(_BaseProjectTests):
  69. def setupModels(self):
  70. super().setupModels()
  71. self.project = Project.new(
  72. name="test_project",
  73. description="Project for a test case",
  74. model=self.model,
  75. root_folder="project_folder",
  76. external_data=False,
  77. data_folder="project_folder/data",)
  78. @property
  79. def url(self):
  80. return url_for("remove_project",
  81. project_id=self.project.id)
  82. def test_project_deletion(self):
  83. self.assertEqual(1, Project.query.count())
  84. self.post(self.url, json=dict(remove=True))
  85. self.assertEqual(0, Project.query.count())
  86. def test_project_deletion_without_flag(self):
  87. self.assertEqual(1, Project.query.count())
  88. self.post(self.url, json=dict(), status_code=400)
  89. self.assertEqual(1, Project.query.count())
  90. self.post(self.url, json=dict(remove=False), status_code=400)
  91. self.assertEqual(1, Project.query.count())
  92. def test_project_deletion_with_labels(self):
  93. self.assertEqual(1, Project.query.count())
  94. self.assertEqual(0, Label.query.count())
  95. for i in range(1, 11):
  96. self.project.create_label(name=f"Label_{i}")
  97. self.assertEqual(10, Label.query.count())
  98. self.post(self.url, json=dict(remove=True))
  99. self.assertEqual(0, Project.query.count())
  100. self.assertEqual(0, Label.query.count())
  101. class ProjectListTests(_BaseProjectTests):
  102. def test_list_projects(self):
  103. self.assertEqual(0, Project.query.count())
  104. for i in range(1, 11):
  105. Project.new(
  106. name=f"TestProject{i}",
  107. description=f"Project for a test case #{i}",
  108. model=self.model,
  109. root_folder=f"project_folder{i}",
  110. external_data=False,
  111. data_folder=f"project_folder{i}/data",
  112. )
  113. self.assertEqual(10, Project.query.count())
  114. response = self.get(url_for("list_projects"))
  115. self.assertTrue(response.is_json)
  116. content = response.json
  117. self.assertEqual(10, len(content))
  118. for entry in content:
  119. project = Project.query.get(entry["id"])
  120. self.assertIsNotNone(project)
  121. self.assertDictEqual(entry, project.serialize())
  122. def test_list_project_collections(self):
  123. project = Project.new(
  124. name="TestProject",
  125. description="Project for a test case",
  126. model=self.model,
  127. root_folder="project_folder",
  128. external_data=False,
  129. data_folder="project_folder/data",
  130. )
  131. self.assertEqual(0, Collection.query.count())
  132. for i in range(1, 11):
  133. project.create_collection(
  134. reference=f"collection_{i}",
  135. name=f"Some collection {i}",
  136. description=f"A description {i}",
  137. position=i,
  138. autoselect=i == 1
  139. )
  140. self.assertEqual(10, Collection.query.count())
  141. response = self.get(url_for("list_collections",
  142. project_id=project.id))
  143. self.assertTrue(response.is_json)
  144. content = response.json
  145. self.assertEqual(10, len(content))
  146. for entry in content:
  147. collection = Collection.query.get(entry["id"])
  148. self.assertIsNotNone(collection)
  149. self.assertDictEqual(entry, collection.serialize())
  150. def test_list_all_files(self):
  151. project = Project.new(
  152. name="TestProject",
  153. description="Project for a test case",
  154. model=self.model,
  155. root_folder="project_folder",
  156. external_data=False,
  157. data_folder="project_folder/data",
  158. )
  159. self.assertEqual(0, File.query.count())
  160. files = []
  161. for i in range(1, 11):
  162. file_uuid = str(uuid.uuid1())
  163. file, is_new = project.add_file(
  164. uuid=file_uuid,
  165. file_type="image",
  166. name=f"name{i}",
  167. filename=f"image_{i:03d}",
  168. extension=".jpg",
  169. size=32*1024,
  170. )
  171. self.assertTrue(is_new)
  172. files.append(file)
  173. self.assertEqual(10, File.query.count())
  174. response = self.get(url_for("list_all_files",
  175. project_id=project.id))
  176. self.assertTrue(response.is_json)
  177. _content = response.json
  178. count = _content["count"]
  179. content = _content["files"]
  180. self.assertEqual(10, count)
  181. self.assertEqual(10, len(content))
  182. for file, entry in zip(files, content):
  183. self.assertDictEqual(entry, file.serialize())
  184. def test_list_some_files(self):
  185. project = Project.new(
  186. name="TestProject",
  187. description="Project for a test case",
  188. model=self.model,
  189. root_folder="project_folder",
  190. external_data=False,
  191. data_folder="project_folder/data",
  192. )
  193. self.assertEqual(0, File.query.count())
  194. files = []
  195. for i in range(1, 11):
  196. file_uuid = str(uuid.uuid1())
  197. file, is_new = project.add_file(
  198. uuid=file_uuid,
  199. file_type="image",
  200. name=f"name{i}",
  201. filename=f"image_{i:03d}",
  202. extension=".jpg",
  203. size=32*1024,
  204. )
  205. self.assertTrue(is_new)
  206. files.append(file)
  207. self.assertEqual(10, File.query.count())
  208. for start, length in [(0, 5), (0, 15), (5, 3), (5, 8)]:
  209. response = self.get(url_for("list_files",
  210. project_id=project.id,
  211. start=start, length=length))
  212. self.assertTrue(response.is_json)
  213. _content = response.json
  214. count = _content["count"]
  215. content = _content["files"]
  216. self.assertEqual(len(files), count)
  217. self.assertEqual(min(len(files), start+length) - start, len(content))
  218. for file, entry in zip(files[start:start+length], content):
  219. self.assertDictEqual(entry, file.serialize())
  220. def test_list_collection_files_of_non_existing_collection(self):
  221. project = Project.new(
  222. name="TestProject",
  223. description="Project for a test case",
  224. model=self.model,
  225. root_folder="project_folder",
  226. external_data=False,
  227. data_folder="project_folder/data",
  228. )
  229. url = url_for("list_collection_files",
  230. project_id=project.id, collection_id=42,
  231. start=0, length=30)
  232. self.get(url, status_code=404)
  233. def test_list_collection_files(self):
  234. project = Project.new(
  235. name="TestProject",
  236. description="Project for a test case",
  237. model=self.model,
  238. root_folder="project_folder",
  239. external_data=False,
  240. data_folder="project_folder/data",
  241. )
  242. self.assertEqual(1, Project.query.count())
  243. collections = {}
  244. for i in range(1, 3):
  245. collection, is_new = project.create_collection(
  246. reference=f"collection_{i}",
  247. name=f"Some collection {i}",
  248. description=f"A description {i}",
  249. position=i,
  250. autoselect=i == 1
  251. )
  252. self.assertTrue(is_new)
  253. collection_files = []
  254. for j in range(1, 4):
  255. file_uuid = str(uuid.uuid1())
  256. file, is_new = collection.add_file(
  257. uuid=file_uuid,
  258. file_type="image",
  259. name=f"col_{i}_name{j}",
  260. filename=f"col_{i}_image_{j:03d}",
  261. extension=".jpg",
  262. size=32*1024,
  263. )
  264. self.assertTrue(is_new)
  265. collection_files.append(file)
  266. collections[collection.id] = collection_files
  267. files = []
  268. for j in range(1, 4):
  269. file_uuid = str(uuid.uuid1())
  270. file, is_new = project.add_file(
  271. uuid=file_uuid,
  272. file_type="image",
  273. name=f"name{j}",
  274. filename=f"image_{j:03d}",
  275. extension=".jpg",
  276. size=32*1024,
  277. )
  278. self.assertTrue(is_new)
  279. files.append(file)
  280. collections[0] = files
  281. self.assertEqual(2, Collection.query.filter(Collection.project_id==project.id).count())
  282. self.assertEqual(6, File.query.filter(
  283. File.project_id == project.id,
  284. File.collection_id != None,
  285. ).count())
  286. self.assertEqual(3, File.query.filter(
  287. File.project_id == project.id,
  288. File.collection_id == None,
  289. ).count())
  290. for collection_id, files in collections.items():
  291. for start, length in [(0, 5), (0, 15), (1, 3), (1, 8)]:
  292. response = self.get(url_for("list_collection_files",
  293. project_id=project.id, collection_id=collection_id,
  294. start=start, length=length))
  295. self.assertTrue(response.is_json)
  296. _content = response.json
  297. count = _content["count"]
  298. content = _content["files"]
  299. self.assertEqual(len(files), count)
  300. self.assertEqual(min(len(files), start+length) - start, len(content))
  301. for file, entry in zip(files[start:start+length], content):
  302. self.assertDictEqual(entry, file.serialize())
  303. def test_list_project_results(self):
  304. project = Project.new(
  305. name="test_project",
  306. description="Project for a test case",
  307. model=self.model,
  308. root_folder="project_folder",
  309. external_data=False,
  310. data_folder="project_folder/data",)
  311. url = url_for("get_project_results", project_id=42)
  312. self.get(url, status_code=404)
  313. files = []
  314. for i in range(1, 4):
  315. file_uuid = str(uuid.uuid1())
  316. file, is_new = project.add_file(
  317. uuid=file_uuid,
  318. file_type="image",
  319. name=f"name{i}",
  320. filename=f"image_{i:03d}",
  321. extension=".jpg",
  322. size=32*1024,
  323. )
  324. self.assertTrue(is_new)
  325. files.append(file)
  326. file.create_result(
  327. origin="user",
  328. result_type="bounding-box",
  329. label=None,
  330. data=dict(x=0, y=0, w=1, h=1)
  331. )
  332. self.assertEqual(3, File.query.count())
  333. self.assertEqual(3, Result.query.count())
  334. url = url_for("get_project_results", project_id=project.id)
  335. response = self.get(url)
  336. self.assertTrue(response.is_json)
  337. content = response.json
  338. self.assertTrue(project.files.count(), len(content))
  339. media_files = {f.filename: MediaFile(f, None) for f in project.files.all()}
  340. for entry in content:
  341. media_file = media_files[entry["filename"]]
  342. self.assertDictEqual(media_file.serialize(), entry)
  343. def test_get_project_model(self):
  344. project = Project.new(
  345. name="test_project",
  346. description="Project for a test case",
  347. model=self.model,
  348. root_folder="project_folder",
  349. external_data=False,
  350. data_folder="project_folder/data",)
  351. url = url_for("get_project_model", project_id=42)
  352. self.get(url, status_code=404)
  353. url = url_for("get_project_model", project_id=project.id)
  354. response = self.get(url)
  355. self.assertTrue(response.is_json)
  356. content = response.json
  357. self.assertDictEqual(project.model.serialize(), content)
  358. class ProjectEditTests(_BaseProjectTests):
  359. def setupModels(self):
  360. super().setupModels()
  361. self.project = Project.new(
  362. name="test_project",
  363. description="Project for a test case",
  364. model=self.model,
  365. root_folder="project_folder",
  366. external_data=False,
  367. data_folder="project_folder/data",)
  368. def test_name_edit(self):
  369. name = "new_name"
  370. url = url_for("edit_project_name", project_id=self.project.id)
  371. self.post(url, json=dict(name=name))
  372. self.assertEqual(name, self.project.name)
  373. def test_name_edit_missing_argument(self):
  374. url = url_for("edit_project_name", project_id=self.project.id)
  375. self.post(url, json=dict(), status_code=400)
  376. self.assertEqual("test_project", self.project.name)
  377. def test_description_edit(self):
  378. description = "New description for the project"
  379. url = url_for("edit_project_description", project_id=self.project.id)
  380. self.post(url, json=dict(description=description))
  381. self.assertEqual(description, self.project.description)
  382. def test_description_edit_missing_argument(self):
  383. url = url_for("edit_project_description", project_id=self.project.id)
  384. self.post(url, json=dict(), status_code=400)
  385. self.assertEqual("Project for a test case", self.project.description)