create_annotations.py 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
  1. #!/usr/bin/env python
  2. if __name__ != '__main__': raise Exception("Do not import me!")
  3. import re
  4. import logging
  5. import numpy as np
  6. from tqdm import tqdm
  7. from os.path import basename, join
  8. from cvargparse import Arg
  9. from utils import parser
  10. from utils import imaging
  11. def _match(content, regex):
  12. match = regex.match(content)
  13. if match is not None:
  14. return int(match.group(1))
  15. def _label_from_foldername(folder_name, label_regex=re.compile(r"^(\d+)\..*$")):
  16. return _match(basename(folder_name), label_regex)
  17. def folder_sorting_key(entry):
  18. folder_name, _ = entry
  19. label = _label_from_foldername(folder_name)
  20. return -1 if label is None else label
  21. def file_sorting_key(fname, im_id_regex=re.compile(r"^.*\_(\d+)\..*$")):
  22. idx = _match(fname, im_id_regex)
  23. return -1 if idx is None else idx
  24. def main(args):
  25. scores = np.loadtxt(args.scores, delimiter=",", skiprows=1, dtype=[
  26. ("class", np.int),
  27. ("filename", np.dtype("U255")),
  28. ("score", np.float),
  29. ])
  30. scores_above = scores["score"] > args.min_score
  31. logging.info(f"{scores_above.sum()} of {len(scores)} are above the score {args.min_score}")
  32. content = imaging.get_content(args.folder, args.extensions)
  33. content = sorted(content, key=folder_sorting_key)
  34. i = 0
  35. rnd = np.random.RandomState(args.seed)
  36. labels = []
  37. split = []
  38. with open(args.images_file, "w") as ims_f:
  39. for folder_name, img_files in tqdm(content):
  40. label = _label_from_foldername(folder_name)
  41. img_files = sorted(img_files, key=file_sorting_key)
  42. cls_mask = scores["class"] == label
  43. assert len(cls_mask) != 0, f"Could not find scores for class {label}"
  44. ims_above_score = scores["filename"][np.logical_and(cls_mask, scores_above)]
  45. ims_above_score = set([basename(fname) for fname in ims_above_score])
  46. for fname in img_files:
  47. if fname not in ims_above_score:
  48. continue
  49. print(i, join(basename(folder_name), fname), file=ims_f)
  50. labels.append(label)
  51. i += 1
  52. n_files = len(ims_above_score)
  53. split.extend(rnd.choice(2, size=n_files, p=[args.ratio, 1-args.ratio]))
  54. np.savetxt(args.labels_file, labels, fmt="%d")
  55. np.savetxt(args.split_file, split, fmt="%d")
  56. logging.info(f"Created annotations for {i} samples")
  57. main(parser.parse_args([
  58. Arg("scores"),
  59. Arg("--min_score", type=float, default=0.6),
  60. Arg("--images_file", default="images.txt"),
  61. Arg("--labels_file", default="labels.txt"),
  62. Arg("--split_file", default="tr_ID.txt"),
  63. Arg("--split_ratio", dest="ratio", type=float, default=.1),
  64. Arg("--seed", type=int, default=42),
  65. ]))