#!/usr/bin/env python
# Merge fine-grained part annotations into coarser part groups
# (e.g. "beak" + "crown" + ... -> "head") as defined by a JSON grouping
# file, and write the re-grouped annotations to an output folder.
if __name__ != '__main__':
    raise Exception("Do not import me!")

import logging

import numpy as np
import simplejson as json
from tqdm import tqdm
from pathlib import Path
from collections import OrderedDict

from cvargparse import Arg
from cvargparse import BaseParser

# One part annotation: image id, part id, (x, y) location and a 0/1
# visibility flag.
# NOTE: "visiblity" (sic) is intentionally kept misspelled — it is the
# field name this script uses consistently for reading and writing.
PartDtype = np.dtype([
    ("img_id", "U255"),
    ("part_id", np.int32),
    ("x", np.int32),
    ("y", np.int32),
    ("visiblity", np.int32),
])

parser = BaseParser([
    Arg("parts", help="parts file"),
    Arg("part_names", help="part names"),
    Arg("grouping_file", help="grouping file in JSON format"),
    Arg("output_folder", help="output folder"),
    Arg("--id_shift", type=int, default=0),
])


def _read_part_names(fname):
    """Read "<id> <name>" lines from *fname*.

    Blank lines are skipped. Returns a pair of ordered mappings:
    (id -> name, name -> id).
    """
    id_to_name = OrderedDict()
    name_to_id = OrderedDict()

    with open(fname, "r") as f:
        for line in f:
            line = line.rstrip()
            if not line:
                continue
            # split only on the first space: names may contain spaces
            # (e.g. "left eye")
            idx, _, name = line.partition(" ")
            id_to_name[int(idx)] = name
            name_to_id[name] = int(idx)

    return id_to_name, name_to_id


def main(args):
    """Re-group part annotations according to ``args.grouping_file``.

    Reads the part locations and part names, averages the positions of
    the parts belonging to each group, and writes ``parts.txt`` (group
    ids and names) and ``part_locs.txt`` (per-sample group annotations)
    into ``args.output_folder``.
    """
    logging.info("reading content")
    with open(args.grouping_file, "r") as f:
        grouping = json.load(f)

    id_to_name, name_to_id = _read_part_names(args.part_names)
    parts = np.loadtxt(args.parts, dtype=PartDtype)

    n_samples = len(np.unique(parts["img_id"]))
    n_parts = len(parts) // n_samples
    n_new_parts = len(grouping)

    assert n_samples * n_parts == len(parts), \
        "n_samples and n_parts was calculated incorrectly!"

    # NOTE(review): the reshape assumes annotations are grouped by image
    # and every image has the same number of parts — holds for the
    # CUB-200-2011 / NA-Birds files this script targets.
    parts = parts.reshape(n_samples, n_parts)
    logging.info(f"Got {parts.size:,d} annotations for {n_samples:,d} samples with {n_parts:,d} parts each")

    new_parts = np.zeros((n_samples, n_new_parts), dtype=PartDtype)
    logging.info(f"Created {new_parts.size:,d} new annotations for {n_samples:,d} samples with {n_new_parts:,d} parts each")

    logging.info("Setting sample ids")
    # every column of a row holds the same img_id, so the first
    # n_new_parts columns suffice (requires n_new_parts <= n_parts)
    new_parts["img_id"] = parts["img_id"][:, :n_new_parts]

    logging.info(f"Setting part ids (id shift={args.id_shift})")
    new_parts["part_id"] = np.arange(n_new_parts)[None] + args.id_shift

    new_names = []
    logging.info("Starting part re-grouping")
    for i, (group_name, part_names) in enumerate(grouping.items()):
        new_names.append((i + args.id_shift, group_name))
        part_idxs = np.array([name_to_id[n] for n in part_names])

        logging.info(f"Grouping {part_idxs} to {new_names[-1]}")
        logging.info(f"Grouping {[id_to_name[p] for p in part_idxs]} to {new_names[-1]}")

        # part ids in the names file are shifted; convert to 0-based
        # column indices
        part_idxs -= args.id_shift

        # determine average position over the parts of this group
        group_positions = parts[:, part_idxs]
        n_visible = group_positions["visiblity"].sum(axis=1)

        # guard against division by zero when no part of the group is
        # visible; those positions are explicitly zeroed below
        # (np.isnan on the integer quotient was a no-op and is gone)
        denom = np.maximum(n_visible, 1)
        for coord in ["x", "y"]:
            new_pos = group_positions[coord].sum(axis=1) // denom
            new_pos[n_visible == 0] = 0
            new_parts[:, i][coord] = new_pos.astype(np.int32)

        # set visibility: a group is visible, when at least one part is present
        new_parts[:, i]["visiblity"] = group_positions["visiblity"].max(axis=1)

    output = Path(args.output_folder)
    if not output.is_dir():
        logging.info("creating output folder")
        # parents=True so nested output paths (e.g. "GT2/parts") work
        output.mkdir(parents=True)

    logging.info("writing groups to \"{}\"".format(args.output_folder))
    np.savetxt(Path(args.output_folder, "parts.txt"),
               np.array(new_names, dtype=[("id", np.int32), ("name", "U255")]),
               fmt="%d %s")
    np.savetxt(Path(args.output_folder, "part_locs.txt"),
               new_parts.reshape(-1),
               fmt="%s %d %d %d %d")


main(parser.parse_args())

"""
Examples:

#### CUB-200-2011 ###
> cat ~/Data/DATASETS/birds/cub200/groups.json
# {
#     "head": [
#         "beak",
#         "crown",
#         "forehead",
#         "left eye",
#         "nape",
#         "right eye",
#         "throat"
#     ],
#     "body": [
#         "back",
#         "belly",
#         "breast",
#         "left wing",
#         "right wing"
#     ],
#     "tail": [
#         "tail"
#     ],
#     "legs": [
#         "left leg",
#         "right leg"
#     ]
# }

> cd ~/Data/DATASETS/birds/cub200/
> .py GT/parts/{part_locs,parts}.txt groups.json GT2/parts --id_shift 1

#### NA-Birds ###
> cat ~/Data/DATASETS/nabirds/cub200/groups.json
# {
#     "head": [
#         "bill",
#         "crown",
#         "nape",
#         "left eye",
#         "right eye"
#     ],
#     "body": [
#         "back",
#         "belly",
#         "breast",
#         "left wing",
#         "right wing"
#     ],
#     "tail": [
#         "tail"
#     ]
# }

> cd ~/Data/DATASETS/birds/nabirds/
> .py GT/parts/{part_locs,parts}.txt groups.json GT2/parts
"""