#!/usr/bin/env python
"""Build image, label and split annotation files from class folders and a score CSV.

Only images whose score in the ``scores`` CSV is strictly above
``--min_score`` are written out. This module is a command-line script and
raises as soon as it is imported.
"""
if __name__ != '__main__':
    raise Exception("Do not import me!")

import re
import logging
import numpy as np

from tqdm import tqdm
from os.path import basename, join

from cvargparse import Arg

from utils import parser
from utils import imaging


def _match(content, regex):
    """Return the first capture group of ``regex`` matched on ``content`` as an int.

    Returns ``None`` if ``regex.match`` (anchored at the start) does not match.
    """
    match = regex.match(content)
    if match is not None:
        return int(match.group(1))
    return None


def _label_from_foldername(folder_name, label_regex=re.compile(r"^(\d+)\..*$")):
    """Parse the numeric class label from a folder named ``<digits>.<anything>``.

    Only the basename of ``folder_name`` is inspected. Returns ``None`` when
    it does not start with digits followed by a dot.
    """
    return _match(basename(folder_name), label_regex)


def folder_sorting_key(entry):
    """Sort key for ``(folder_name, img_files)`` entries: the class label, or -1 if none."""
    folder_name, _ = entry
    label = _label_from_foldername(folder_name)
    return -1 if label is None else label


def file_sorting_key(fname, im_id_regex=re.compile(r"^.*\_(\d+)\..*$")):
    """Sort key for image file names.

    Returns the digits of the right-most ``_<digits>.`` occurrence in
    ``fname`` (the leading ``.*`` is greedy), or -1 if there is none.
    """
    idx = _match(fname, im_id_regex)
    return -1 if idx is None else idx


def main(args):
    """Write the images, labels and split files for all images above ``args.min_score``.

    ``args.scores`` is read as a comma-separated file with one header row and
    the columns ``class``, ``filename`` and ``score``. Folders returned by
    ``imaging.get_content(args.folder, args.extensions)`` are processed in
    label order, and their images in ``file_sorting_key`` order.

    For every kept image:
    - one line ``<index> <folder basename>/<file>`` is written to ``args.images_file``;
    - its label is written to ``args.labels_file``;
    - a random 0/1 value (0 with probability ``args.ratio``) is written to
      ``args.split_file``.

    All three outputs are line-aligned.

    NOTE(review): ``args.folder`` and ``args.extensions`` are not declared in
    this file; presumably ``utils.parser`` adds them. Confirm.

    Raises:
        ValueError: if the score file has no row for a folder's class label
            (including folders whose name yields no label).
    """
    scores = np.loadtxt(
        args.scores, delimiter=",", skiprows=1,
        # ndmin=1: a single data row would otherwise give a 0-d array,
        # on which len() fails.
        ndmin=1,
        # np.int / np.float were only aliases of the builtins and were
        # removed in NumPy 1.24.
        dtype=[
            ("class", int),
            ("filename", np.dtype("U255")),
            ("score", float),
        ])

    scores_above = scores["score"] > args.min_score
    logging.info(f"{scores_above.sum()} of {len(scores)} are above the score {args.min_score}")

    content = imaging.get_content(args.folder, args.extensions)
    content = sorted(content, key=folder_sorting_key)

    i = 0
    rnd = np.random.RandomState(args.seed)
    labels = []
    split = []
    with open(args.images_file, "w") as ims_f:
        for folder_name, img_files in tqdm(content):
            label = _label_from_foldername(folder_name)
            img_files = sorted(img_files, key=file_sorting_key)

            cls_mask = scores["class"] == label
            # cls_mask always has len(scores) entries, so check whether any
            # score row actually belongs to this class.
            if not cls_mask.any():
                raise ValueError(f"Could not find scores for class {label}")

            keep = np.logical_and(cls_mask, scores_above)
            ims_above_score = {basename(fname) for fname in scores["filename"][keep]}

            n_written = 0
            for fname in img_files:
                if fname not in ims_above_score:
                    continue
                print(i, join(basename(folder_name), fname), file=ims_f)
                labels.append(label)
                i += 1
                n_written += 1

            # Draw one split value per image actually written, not per scored
            # file name. This keeps the split file aligned with the images and
            # labels files even when a scored file is absent from the folder.
            split.extend(rnd.choice(2, size=n_written, p=[args.ratio, 1 - args.ratio]))

    np.savetxt(args.labels_file, labels, fmt="%d")
    np.savetxt(args.split_file, split, fmt="%d")
    logging.info(f"Created annotations for {i} samples")


main(parser.parse_args([
    Arg("scores"),

    Arg("--min_score", type=float, default=0.6),

    Arg("--images_file", default="images.txt"),
    Arg("--labels_file", default="labels.txt"),
    Arg("--split_file", default="tr_ID.txt"),

    Arg("--split_ratio", dest="ratio", type=float, default=.1),
    Arg("--seed", type=int, default=42),
]))