Browse Source

Cleaned up the info file; added a script for part regrouping.

Dimitri Korsch 3 years ago
parent
commit
9f2b863995
2 changed files with 180 additions and 41 deletions
  1. 180 0
      scripts/group_parts.py
  2. 0 41
      scripts/info_files/info.yml

+ 180 - 0
scripts/group_parts.py

@@ -0,0 +1,180 @@
+#!/usr/bin/env python
+if __name__ != '__main__': raise Exception("Do not import me!")
+
+import logging
+import numpy as np
+import simplejson as json
+
+from tqdm import tqdm
+from pathlib import Path
+from collections import OrderedDict
+
+from cvargparse import Arg
+from cvargparse import BaseParser
+
# Structured dtype for one part-annotation row of the input file:
# image id, part id, pixel coordinates, and a 0/1 visibility flag.
# NOTE(review): "visiblity" is misspelled, but the field name is used
# consistently everywhere in this script (and is a runtime string),
# so it must not be renamed in isolation.
PartDtype = np.dtype([
	("img_id", "U255"),
	("part_id", np.int32),
	("x", np.int32),
	("y", np.int32),
	("visiblity", np.int32),
])
+
# CLI: four positional file/folder paths plus an optional id shift.
# --id_shift compensates for datasets whose part ids are 1-based
# (e.g. the CUB-200-2011 example below uses --id_shift 1) so that
# ids can be converted to 0-based column indices.
parser = BaseParser([
	Arg("parts",
		help="parts file"),
	Arg("part_names",
		help="part names"),
	Arg("grouping_file",
		help="grouping file in JSON format"),
	Arg("output_folder",
		help="output folder"),

	Arg("--id_shift", type=int, default=0),
])
+
+def _read_part_names(fname):
+	id_to_name = OrderedDict()
+	name_to_id = OrderedDict()
+	with open(fname, "r") as f:
+		for line in f:
+			line = line.rstrip()
+			if not line:
+				continue
+			idx, _, name = line.partition(" ")
+			id_to_name[int(idx)] = name
+			name_to_id[name] = int(idx)
+	return id_to_name, name_to_id
+
+
def main(args):
	"""Re-group fine-grained part annotations into coarser part groups.

	Reads a part-locations file (one ``img_id part_id x y visibility``
	row per line), a part-names file, and a JSON grouping file mapping
	``group name -> [member part names]``.  Writes a new
	``parts.txt`` / ``part_locs.txt`` pair into ``args.output_folder``.

	A group's location is the integer mean over its *visible* member
	parts (0 if none is visible); a group is visible when at least one
	member part is visible.
	"""
	logging.info("reading content")
	with open(args.grouping_file, "r") as f:
		grouping = json.load(f)

	id_to_name, name_to_id = _read_part_names(args.part_names)
	parts = np.loadtxt(args.parts, dtype=PartDtype)

	# The input is assumed to contain the same number of part rows for
	# every sample; recover the (n_samples, n_parts) layout from that.
	n_samples = len(np.unique(parts["img_id"]))
	n_parts = len(parts) // n_samples
	n_new_parts = len(grouping)

	assert n_samples * n_parts == len(parts), \
		"n_samples and n_parts was calculated incorrectly!"

	parts = parts.reshape(n_samples, n_parts)
	logging.info(f"Got {parts.size:,d} annotations for {n_samples:,d} samples with {n_parts:,d} parts each")

	new_parts = np.zeros((n_samples, n_new_parts), dtype=PartDtype)
	logging.info(f"Created {new_parts.size:,d} new annotations for {n_samples:,d} samples with {n_new_parts:,d} parts each")

	logging.info("Setting sample ids")

	# every column of a row shares the same img_id, so any n_new_parts
	# columns of the original annotations carry the right ids
	new_parts["img_id"] = parts["img_id"][:, :n_new_parts]

	logging.info(f"Setting part ids (id shift={args.id_shift})")
	new_parts["part_id"] = np.arange(n_new_parts)[None] + args.id_shift

	new_names = []
	logging.info("Starting part re-grouping")

	for i, (group_name, part_names) in enumerate(grouping.items()):
		new_names.append((i + args.id_shift, group_name))

		part_idxs = np.array([name_to_id[n] for n in part_names])
		logging.info(f"Grouping {part_idxs} to {new_names[-1]}")
		logging.info(f"Grouping {[id_to_name[p] for p in part_idxs]} to {new_names[-1]}")

		# part ids in the input files are shifted by id_shift;
		# undo that to get 0-based column indices
		part_idxs -= args.id_shift

		# determine average position over the visible member parts only
		group_positions = parts[:, part_idxs]
		n_visible = group_positions["visiblity"].sum(axis=1)

		# guard against integer division by zero for samples where no
		# member part is visible (the original np.isnan check was dead
		# code on integer arrays); such coordinates are forced to 0
		divisor = np.maximum(n_visible, 1)

		for coord in ["x", "y"]:
			new_pos = group_positions[coord].sum(axis=1) // divisor
			new_pos[n_visible == 0] = 0

			new_parts[:, i][coord] = new_pos.astype(np.int32)

		# set visibility: a group is visible, when at least one part is present
		new_parts[:, i]["visiblity"] = group_positions["visiblity"].max(axis=1)

	output = Path(args.output_folder)

	if not output.is_dir():
		logging.info("creating output folder")
		# parents=True so a nested output path (e.g. GT2/parts) works
		output.mkdir(parents=True)

	logging.info("writing groups to \"{}\"".format(args.output_folder))

	np.savetxt(Path(args.output_folder, "parts.txt"),
		np.array(new_names, dtype=[("id", np.int32), ("name", "U255")]),
		fmt="%d %s")

	np.savetxt(Path(args.output_folder, "part_locs.txt"), new_parts.reshape(-1),
		fmt="%s %d %d %d %d")

main(parser.parse_args())
+
+"""
+Examples:
+
+#### CUB-200-2011 ###
+> cat ~/Data/DATASETS/birds/cub200/groups.json
+# {
+#   "head": [
+#     "beak",
+#     "crown",
+#     "forehead",
+#     "left eye",
+#     "nape",
+#     "right eye",
+#     "throat"
+#   ],
+#   "body": [
+#     "back",
+#     "belly",
+#     "breast",
+#     "left wing",
+#     "right wing"
+#   ],
+#   "tail": [
+#     "tail"
+#   ],
+#   "legs": [
+#     "left leg",
+#     "right leg"
+#   ]
+# }
+> cd ~/Data/DATASETS/birds/cub200/
+> <path_to_script>.py GT/parts/{part_locs,parts}.txt groups.json GT2/parts --id_shift 1
+
+#### NA-Birds ###
+> cat ~/Data/DATASETS/birds/nabirds/groups.json
+# {
+#   "head": [
+#     "bill",
+#     "crown",
+#     "nape",
+#     "left eye",
+#     "right eye"
+#   ],
+#   "body": [
+#     "back",
+#     "belly",
+#     "breast",
+#     "left wing",
+#     "right wing"
+#   ],
+#   "tail": [
+#     "tail"
+#   ]
+# }
+> cd ~/Data/DATASETS/birds/nabirds/
+> <path_to_script>.py GT/parts/{part_locs,parts}.txt groups.json GT2/parts
+
+"""

+ 0 - 41
scripts/info_files/info.yml

@@ -22,47 +22,6 @@ MODELS:
     weights:
       imagenet: model.npz
 
-  # efficientnet:    &efficientnet
-  #   folder: efficientnet
-  #   class_key: efficientnet
-  #   weights: model.imagenet.npz
-
-  # inception_inat:    &inception_inat
-  #   folder: inception
-  #   class_key: inception
-  #   weights: model.inat.ckpt.npz
-
-  # inception_imagenet:    &inception_inet
-  #   folder: inception
-  #   class_key: inception
-  #   weights: model.imagenet.ckpt.npz
-
-  # inception:
-  #   <<: *inception_inat
-
-  # inception_tf_inat:  &inception_tf_inat
-  #   folder: inception_tf
-  #   class_key: inception_tf
-  #   weights: inception_v3_iNat_299.ckpt
-
-  # inception_tf_inet:  &inception_tf_inet
-  #   folder: inception_tf
-  #   class_key: inception_tf
-  #   weights: inception_v3_ILSVRC_299.ckpt
-
-  # inception_tf:  &inception_tf
-  #   <<: *inception_tf_inat
-
-  # resnet:       &resnet50
-  #   folder: resnet
-  #   class_key: resnet
-  #   weights: model.npz
-
-  # vgg19:       &vgg19
-  #   folder: vgg19
-  #   class_key: vgg19
-  #   weights: model.npz
-
 ############ Existing Datasets
 DATASETS: