#!/usr/bin/env python
# group_parts.py
  2. if __name__ != '__main__': raise Exception("Do not import me!")
import logging
from collections import OrderedDict
from pathlib import Path

import numpy as np
import simplejson as json
from cvargparse import Arg
from cvargparse import BaseParser
from tqdm import tqdm
  11. PartDtype = np.dtype([
  12. ("img_id", "U255"),
  13. ("part_id", np.int32),
  14. ("x", np.int32),
  15. ("y", np.int32),
  16. ("visiblity", np.int32),
  17. ])
  18. parser = BaseParser([
  19. Arg("parts",
  20. help="parts file"),
  21. Arg("part_names",
  22. help="part names"),
  23. Arg("grouping_file",
  24. help="grouping file in JSON format"),
  25. Arg("output_folder",
  26. help="output folder"),
  27. Arg("--id_shift", type=int, default=0),
  28. ])
  29. def _read_part_names(fname):
  30. id_to_name = OrderedDict()
  31. name_to_id = OrderedDict()
  32. with open(fname, "r") as f:
  33. for line in f:
  34. line = line.rstrip()
  35. if not line:
  36. continue
  37. idx, _, name = line.partition(" ")
  38. id_to_name[int(idx)] = name
  39. name_to_id[name] = int(idx)
  40. return id_to_name, name_to_id
  41. def main(args):
  42. logging.info("reading content")
  43. with open(args.grouping_file, "r") as f:
  44. grouping = json.load(f)
  45. id_to_name, name_to_id = _read_part_names(args.part_names)
  46. parts = np.loadtxt(args.parts, dtype=PartDtype)
  47. n_samples = len(np.unique(parts["img_id"]))
  48. n_parts = len(parts) // n_samples
  49. n_new_parts = len(grouping)
  50. assert n_samples * n_parts == len(parts), \
  51. "n_samples and n_parts was calculated incorrectly!"
  52. parts = parts.reshape(n_samples, n_parts)
  53. logging.info(f"Got {parts.size:,d} annotations for {n_samples:,d} samples with {n_parts:,d} parts each")
  54. new_parts = np.zeros((n_samples, n_new_parts), dtype=PartDtype)
  55. logging.info(f"Created {new_parts.size:,d} new annotations for {n_samples:,d} samples with {n_new_parts:,d} parts each")
  56. logging.info("Setting sample ids")
  57. new_parts["img_id"] = parts["img_id"][:, :n_new_parts]
  58. logging.info(f"Setting part ids (id shift={args.id_shift})")
  59. new_parts["part_id"] = np.arange(n_new_parts)[None] + args.id_shift
  60. new_names = []
  61. logging.info("Starting part re-grouping")
  62. for i, (group_name, part_names) in enumerate(grouping.items()):
  63. new_names.append((i + args.id_shift, group_name))
  64. part_idxs = np.array([name_to_id[n] for n in part_names])
  65. logging.info(f"Grouping {part_idxs} to {new_names[-1]}")
  66. logging.info(f"Grouping {[id_to_name[p] for p in part_idxs]} to {new_names[-1]}")
  67. part_idxs -= args.id_shift
  68. # determine average position
  69. group_positions = parts[:, part_idxs]
  70. import pdb; pdb.set_trace()
  71. n_visible = group_positions["visiblity"].sum(axis=1)
  72. for coord in ["x", "y"]:
  73. new_pos = group_positions[coord].sum(axis=1) // n_visible
  74. new_pos[np.isnan(new_pos)] = 0
  75. new_parts[:, i][coord] = new_pos.astype(np.int32)
  76. # set visibility: a group is visible, when at least one part is present
  77. new_parts[:, i]["visiblity"] = group_positions["visiblity"].max(axis=1)
  78. # hack aroung divide by 0
  79. output = Path(args.output_folder)
  80. if not output.is_dir():
  81. logging.info("creating output folder")
  82. output.mkdir()
  83. logging.info("writing groups to \"{}\"".format(args.output_folder))
  84. np.savetxt(Path(args.output_folder, "parts.txt"),
  85. np.array(new_names, dtype=[("id", np.int32), ("name", "U255")]),
  86. fmt="%d %s")
  87. np.savetxt(Path(args.output_folder, "part_locs.txt"), new_parts.reshape(-1),
  88. fmt="%s %d %d %d %d")
  89. main(parser.parse_args())
"""
Examples:

#### CUB-200-2011 ####

> cat ~/Data/DATASETS/birds/cub200/groups.json
# {
#     "head": [
#         "beak",
#         "crown",
#         "forehead",
#         "left eye",
#         "nape",
#         "right eye",
#         "throat"
#     ],
#     "body": [
#         "back",
#         "belly",
#         "breast",
#         "left wing",
#         "right wing"
#     ],
#     "tail": [
#         "tail"
#     ],
#     "legs": [
#         "left leg",
#         "right leg"
#     ]
# }

> cd ~/Data/DATASETS/birds/cub200/
> <path_to_script>.py GT/parts/{part_locs,parts}.txt groups.json GT2/parts --id_shift 1

#### NA-Birds ####

> cat ~/Data/DATASETS/birds/nabirds/groups.json
# {
#     "head": [
#         "bill",
#         "crown",
#         "nape",
#         "left eye",
#         "right eye"
#     ],
#     "body": [
#         "back",
#         "belly",
#         "breast",
#         "left wing",
#         "right wing"
#     ],
#     "tail": [
#         "tail"
#     ]
# }

> cd ~/Data/DATASETS/birds/nabirds/
> <path_to_script>.py GT/parts/{part_locs,parts}.txt groups.json GT2/parts
"""