6
0

result.py 7.4 KB


  1. import click
  2. import flask
  3. import simplejson as json
  4. from flask.cli import AppGroup
  5. from tqdm import tqdm
  6. from pycs import app
  7. from pycs import database as db
# Flask CLI command group named "result"; the commands defined in this
# module attach themselves to it through ``@result_cli.command(...)``.
result_cli = AppGroup("result", short_help="Result operations")
  9. @result_cli.command("export")
  10. @click.argument("project_id")
  11. @click.argument("indent", required=False)
  12. @click.argument("output", required=False)
  13. def export(project_id, output, indent):
  14. """ Export results for a specific project or for all projects """
  15. if project_id == "all":
  16. projects = db.Project.query.all()
  17. app.logger.info(f"Exporting results for all projects ({len(projects)})!")
  18. if output is None:
  19. output = "output.json"
  20. else:
  21. project = db.Project.query.get(project_id)
  22. if project is None:
  23. app.logger.error(f"Could not find project with ID {project_id}!")
  24. return
  25. app.logger.info(f"Exporting results for project {project}!")
  26. projects = [project]
  27. if output is None:
  28. output = f"output_project_{int(project_id):04d}.json"
  29. app.logger.info(f"Exporting to {output}")
  30. results = []
  31. for project in projects:
  32. project_files = [
  33. dict(**f.serialize(),
  34. results=[
  35. dict(**r.serialize(), label=r.label.serialize() if r.label is not None else None)
  36. for r in f.results.all()
  37. ])
  38. for f in tqdm(project.files.all()) if f.results.count() != 0
  39. ]
  40. results.append(dict(
  41. project_id=project.id,
  42. files=project_files,
  43. labels=[lab.serialize() for lab in project.labels.all()],
  44. ))
  45. if indent is not None:
  46. indent = int(indent)
  47. with open(output, "w", encoding="utf-8") as out_f:
  48. flask.json.dump(results, out_f, app=app, indent=indent)
  49. @result_cli.command("restore")
  50. @click.argument("infile")
  51. @click.option("--dry-run", is_flag=True)
  52. def restore(infile, dry_run):
  53. with open(infile) as f:
  54. results = json.load(f)
  55. for project_results in results:
  56. project = db.Project.get_or_404(project_results["project_id"])
  57. for file_results in project_results["files"]:
  58. file = db.File.get_or_404(file_results["id"])
  59. assert file.path == file_results["path"]
  60. # first check for new and changed results
  61. for _result in file_results["results"]:
  62. if not _is_data_valid(**_result):
  63. continue
  64. result = get_result_or_none(file, **_result)
  65. user1 = _result["origin_user"]
  66. data1 = _result["data"]
  67. ref1 = (_result["label"] or {}).get("reference")
  68. # lab1 = (_result["label"] or {}).get("id")
  69. if result is None:
  70. # we have a new result entry
  71. if not dry_run:
  72. file.create_result(
  73. result_type="bounding-box",
  74. origin="user",
  75. origin_user=user1,
  76. label=ref1,
  77. data=data1,
  78. commit=True
  79. )
  80. print(" | ".join([
  81. f"Project #{project.id:< 6d}"
  82. f"File #{file.id:< 6d} [{file.name:^30s}]",
  83. "[New Result]",
  84. f"User: {user1 or '':<10s}",
  85. f"Data: {data1}, Label-Ref: {ref1}",
  86. ])
  87. )
  88. continue
  89. assert result.file_id == _result["file_id"]
  90. user0 = result.origin_user
  91. data0 = result.data
  92. ref0 = getattr(result.label, "reference", None)
  93. # lab0 = getattr(result.label, "id", None)
  94. is_same_data = _check_data(data0, data1)
  95. if is_same_data and (ref0 == ref1 or ref1 is None):
  96. # nothing to change
  97. continue
  98. print(" | ".join([
  99. f"Project #{project.id:< 6d}"
  100. f"File #{file.id:< 6d} [{file.name:^30s}]",
  101. ]), end=" | "
  102. )
  103. if not is_same_data:
  104. # data was updated
  105. print(" | ".join([
  106. "[Data updated]",
  107. f"User: {user1 or '':<10s}",
  108. f"Data: {data0} -> {data1}"
  109. ]), end=" | "
  110. )
  111. assert user1 is not None
  112. if not dry_run:
  113. result.origin_user = user1
  114. result.data = data1
  115. if ref0 != ref1:
  116. assert user1 is not None
  117. if not dry_run:
  118. result.origin_user = user1
  119. if ref1 is None:
  120. # label was deleted
  121. print("[Label Deleted]")
  122. if not dry_run:
  123. result.label_id = None
  124. else:
  125. # label was updated
  126. print(" | ".join([
  127. "[Label updated]",
  128. f"User: {user0 or '':<10s} -> {user1 or '':<10s}",
  129. f"{ref0 or 'UNK':<6s} -> {ref1 or 'UNK':<6s}"
  130. ])
  131. )
  132. label = project.label_by_reference(ref1)
  133. if not dry_run:
  134. result.label_id = label.id
  135. else:
  136. print()
  137. if not dry_run:
  138. result.commit()
  139. # then check for deleted results
  140. for result in file.results.all():
  141. if result.origin != "user" or result.type != "bounding-box":
  142. continue
  143. found = False
  144. for _result in file_results["results"]:
  145. if not _is_data_valid(**_result):
  146. continue
  147. if _check_data(result.data, _result["data"]):
  148. found = True
  149. break
  150. if not found:
  151. print(" | ".join([
  152. f"Project #{project.id:< 6d}"
  153. f"File #{file.id:< 6d} [{file.name:^30s}]",
  154. "[Result deleted]",
  155. f"{result.data}",
  156. f"{result.label}",
  157. ])
  158. )
  159. if not dry_run:
  160. result.delete()
  161. def _is_data_valid(*, data, type, origin, **kwargs):
  162. wh = (None, None) if data is None else (data["w"], data["h"])
  163. return (type != "labeled-image" and
  164. origin == "user" and
  165. 0 not in wh)
  166. def _check_data(data0, data1):
  167. if None in (data0, data1):
  168. return data0 == data1 == None
  169. for key in data0:
  170. if data1.get(key) != data0.get(key):
  171. return False
  172. return True
  173. def get_result_or_none(file: db.File, id: int, data: dict, **kwargs):
  174. result = db.Result.query.filter(
  175. db.Result.id==id, db.Result.file_id==file.id).one_or_none()
  176. if result is not None:
  177. return result
  178. for other_results in file.results.all():
  179. if _check_data(data, other_results.data):
  180. # import pdb; pdb.set_trace()
  181. return other_results