#!/usr/bin/env python
# This file is meant to be run as a command-line script only; refuse imports.
if __name__ != '__main__':
	raise Exception("Do not import me!")

import logging
import numpy as np
import simplejson as json

from tqdm import tqdm
from pathlib import Path
from collections import OrderedDict

from cvargparse import Arg
from cvargparse import BaseParser

# Record layout of one row of a part-locations file:
#   <img_id> <part_id> <x> <y> <visibility>
# NOTE: the field is spelled "visiblity" (sic); the rest of this script indexes
# it under that exact name, so it must not be renamed here alone.
PartDtype = np.dtype({
	"names": ["img_id", "part_id", "x", "y", "visiblity"],
	"formats": ["U255", np.int32, np.int32, np.int32, np.int32],
})

# Command-line interface: four positional paths plus the part-id offset.
parser = BaseParser([
	Arg("parts",
		help="parts file"),
	Arg("part_names",
		help="part names"),
	Arg("grouping_file",
		help="grouping file in JSON format"),
	Arg("output_folder",
		help="output folder"),

	# previously undocumented: controls how part ids map to array columns
	Arg("--id_shift", type=int, default=0,
		help="offset between a part's 0-based column index and its part id "
		"(e.g. 1 for the 1-based part ids of CUB-200-2011); also added to "
		"the ids of the new groups"),
])

def _read_part_names(fname):
	id_to_name = OrderedDict()
	name_to_id = OrderedDict()
	with open(fname, "r") as f:
		for line in f:
			line = line.rstrip()
			if not line:
				continue
			idx, _, name = line.partition(" ")
			id_to_name[int(idx)] = name
			name_to_id[name] = int(idx)
	return id_to_name, name_to_id


def _regroup_parts(parts, grouping, id_to_name, name_to_id, id_shift):
	"""Merge the original part annotations into the groups defined by *grouping*.

	Args:
		parts: PartDtype array of shape (n_samples, n_parts); column ``j`` must
			hold part id ``j + id_shift`` of a single image (verified by ``main``).
		grouping: mapping of new group name -> list of original part names;
			the iteration order of the mapping defines the new part ids.
		id_to_name: part id -> part name (only used for logging).
		name_to_id: part name -> part id.
		id_shift: offset added to 0-based indices to obtain part ids; used for
			both the input columns and the ids of the new groups.

	Returns:
		``(new_parts, new_names)``: a PartDtype array of shape
		``(n_samples, len(grouping))`` and a list of ``(new_id, group_name)``.

	Raises:
		KeyError: a group references an unknown part name.
		ValueError: a group is empty or references a part id without a column
			in *parts*.
	"""
	n_samples, n_parts = parts.shape
	n_new_parts = len(grouping)

	new_parts = np.zeros((n_samples, n_new_parts), dtype=PartDtype)
	logging.info(f"Created {new_parts.size:,d} new annotations for {n_samples:,d} samples with {n_new_parts:,d} parts each")

	logging.info("Setting sample ids")
	# every column of a row belongs to the same image: broadcast the first
	# column (slicing n_new_parts columns broke when groups outnumber parts)
	new_parts["img_id"] = parts["img_id"][:, :1]

	logging.info(f"Setting part ids (id shift={id_shift})")
	new_parts["part_id"] = np.arange(n_new_parts)[None] + id_shift

	new_names = []
	logging.info("Starting part re-grouping")

	for i, (group_name, part_names) in enumerate(grouping.items()):
		if not part_names:
			raise ValueError(f"Group {group_name!r} does not contain any parts")

		new_names.append((i + id_shift, group_name))

		part_idxs = np.array([name_to_id[n] for n in part_names])
		logging.info(f"Grouping {part_idxs} to {new_names[-1]}")
		logging.info(f"Grouping {[id_to_name[p] for p in part_idxs]} to {new_names[-1]}")

		# part ids -> column indices; reject ids that would silently wrap around
		columns = part_idxs - id_shift
		if columns.min() < 0 or columns.max() >= n_parts:
			raise ValueError(
				f"Group {group_name!r} references part ids {part_idxs} outside "
				f"of the annotated range [{id_shift}, {n_parts + id_shift})")

		group_positions = parts[:, columns]
		visible = group_positions["visiblity"] > 0
		# samples without any visible part of this group keep position (0, 0);
		# np.maximum(..., 1) avoids dividing by zero for them
		divisor = np.maximum(visible.sum(axis=1), 1)

		for coord in ("x", "y"):
			# average over the visible parts only, so stale coordinates of
			# invisible parts cannot pull the group position away
			coord_sum = np.where(visible, group_positions[coord], 0).sum(axis=1)
			new_parts[coord][:, i] = coord_sum // divisor

		# a group is visible when at least one of its parts is visible
		new_parts["visiblity"][:, i] = group_positions["visiblity"].max(axis=1)

	return new_parts, new_names


def main(args):
	"""Regroup part annotations into coarser part groups and write them out.

	Inputs (paths taken from *args*):
		parts: part-location file parsed with ``PartDtype`` (one
			``img_id part_id x y visibility`` row per line). Rows must be
			grouped by image, and every image must list the same parts in
			ascending part-id order starting at ``args.id_shift``.
		part_names: ``<id> <name>`` lines, parsed by ``_read_part_names``.
		grouping_file: JSON object mapping each new group name to the list of
			original part names it merges.

	Outputs (written to ``args.output_folder``, which is created together
	with missing parent folders):
		parts.txt: ``<new id> <group name>`` per group.
		part_locs.txt: one ``img_id part_id x y visibility`` row per image and
			group; x/y are the floor of the mean over the group's visible parts
			(0 when none is visible).

	Raises:
		ValueError: the annotations are empty or not laid out as described.
	"""
	logging.info("reading content")
	with open(args.grouping_file, "r") as f:
		grouping = json.load(f)

	id_to_name, name_to_id = _read_part_names(args.part_names)
	# ndmin=1 keeps a single-row file from collapsing into a 0-d array
	parts = np.loadtxt(args.parts, dtype=PartDtype, ndmin=1)
	if len(parts) == 0:
		raise ValueError(f"No part annotations found in \"{args.parts}\"")

	n_samples = len(np.unique(parts["img_id"]))
	n_parts = len(parts) // n_samples

	# explicit check instead of `assert`, which is stripped under `python -O`
	if n_samples * n_parts != len(parts):
		raise ValueError(
			f"{len(parts):,d} annotations cannot be split evenly into "
			f"{n_samples:,d} samples")

	parts = parts.reshape(n_samples, n_parts)

	# the column-based indexing in _regroup_parts relies on this layout;
	# verify it instead of silently averaging the wrong parts
	expected_ids = np.arange(n_parts) + args.id_shift
	mixed_images = (parts["img_id"] != parts["img_id"][:, :1]).any()
	misordered_ids = (parts["part_id"] != expected_ids).any()
	if mixed_images or misordered_ids:
		raise ValueError(
			"Part annotations must be grouped by image and list the part ids "
			f"{args.id_shift}..{n_parts - 1 + args.id_shift} in order for every "
			"image (is --id_shift correct?)")

	logging.info(f"Got {parts.size:,d} annotations for {n_samples:,d} samples with {n_parts:,d} parts each")

	new_parts, new_names = _regroup_parts(
		parts, grouping, id_to_name, name_to_id, args.id_shift)

	output = Path(args.output_folder)
	if not output.is_dir():
		logging.info("creating output folder")
		# parents=True: the documented usage writes into e.g. "GT2/parts"
		output.mkdir(parents=True, exist_ok=True)

	logging.info("writing groups to \"{}\"".format(args.output_folder))

	np.savetxt(output / "parts.txt",
		np.array(new_names, dtype=[("id", np.int32), ("name", "U255")]),
		fmt="%d %s")

	np.savetxt(output / "part_locs.txt", new_parts.reshape(-1),
		fmt="%s %d %d %d %d")

if __name__ == "__main__":
	# parse the command line and run the regrouping
	main(parser.parse_args())

"""
Examples:

#### CUB-200-2011 ###
> cat ~/Data/DATASETS/birds/cub200/groups.json
# {
#   "head": [
#     "beak",
#     "crown",
#     "forehead",
#     "left eye",
#     "nape",
#     "right eye",
#     "throat"
#   ],
#   "body": [
#     "back",
#     "belly",
#     "breast",
#     "left wing",
#     "right wing"
#   ],
#   "tail": [
#     "tail"
#   ],
#   "legs": [
#     "left leg",
#     "right leg"
#   ]
# }
> cd ~/Data/DATASETS/birds/cub200/
> <path_to_script>.py GT/parts/{part_locs,parts}.txt groups.json GT2/parts --id_shift 1

#### NA-Birds ###
> cat ~/Data/DATASETS/birds/nabirds/groups.json
# {
#   "head": [
#     "bill",
#     "crown",
#     "nape",
#     "left eye",
#     "right eye"
#   ],
#   "body": [
#     "back",
#     "belly",
#     "breast",
#     "left wing",
#     "right wing"
#   ],
#   "tail": [
#     "tail"
#   ]
# }
> cd ~/Data/DATASETS/birds/nabirds/
> <path_to_script>.py GT/parts/{part_locs,parts}.txt groups.json GT2/parts

"""