train_bow.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
  1. # Copyright (c) 2023 Felix Kleinsteuber and Computer Vision Group, Friedrich Schiller University Jena
  2. # Approach 3: Local features
  3. # This script is used for generating a BOW vocabulary using
  4. # densely sampled SIFT features on Lapse images.
  5. # See eval_bow.py for evaluation.
  6. import argparse
  7. import os
  8. import numpy as np
  9. from timeit import default_timer as timer
  10. from datetime import timedelta
  11. from py.Dataset import Dataset
  12. from py.LocalFeatures import extract_descriptors, generate_dictionary_from_descriptors, generate_bow_features, pick_random_descriptors
  13. def main():
  14. parser = argparse.ArgumentParser(description="BOW train script")
  15. parser.add_argument("dataset_dir", type=str, help="Directory of the dataset containing all session folders")
  16. parser.add_argument("session_name", type=str, help="Name of the session to use for Lapse images (e.g. marten_01)")
  17. parser.add_argument("--clusters", type=int, help="Number of clusters / BOW vocabulary size", default=1024)
  18. parser.add_argument("--step_size", type=int, help="DSIFT keypoint step size. Smaller step size = more keypoints.", default=30)
  19. parser.add_argument("--keypoint_size", type=int, help="DSIFT keypoint size. Defaults to step_size.", default=-1)
  20. parser.add_argument("--include_motion", action="store_true", help="Include motion images for training.")
  21. parser.add_argument("--random_prototypes", action="store_true", help="Pick random prototype vectors instead of doing kmeans.")
  22. parser.add_argument("--num_vocabularies", type=int, help="Number of vocabularies to generate if random prototype choosing is enabled.", default=10)
  23. args = parser.parse_args()
  24. if args.keypoint_size <= 0:
  25. args.keypoint_size = args.step_size
  26. print(f"Using keypoint size {args.keypoint_size} with step size {args.step_size}.")
  27. ds = Dataset(args.dataset_dir)
  28. session = ds.create_session(args.session_name)
  29. save_dir = f"./bow_train_NoBackup/{session.name}"
  30. suffix = ""
  31. if args.include_motion:
  32. suffix += "_motion"
  33. print("Including motion data for prototype selection!")
  34. if args.random_prototypes:
  35. suffix += "_random"
  36. print("Picking random prototypes instead of using kmeans!")
  37. lapse_dscs_file = os.path.join(save_dir, f"lapse_dscs_{args.step_size}_{args.keypoint_size}.npy")
  38. motion_dscs_file = os.path.join(save_dir, f"motion_dscs_{args.step_size}_{args.keypoint_size}.npy")
  39. dictionary_file = os.path.join(save_dir, f"bow_dict_{args.step_size}_{args.keypoint_size}_{args.clusters}{suffix}.npy")
  40. train_feat_file = os.path.join(save_dir, f"bow_train_{args.step_size}_{args.keypoint_size}_{args.clusters}{suffix}.npy")
  41. # Lapse DSIFT descriptors
  42. if os.path.isfile(lapse_dscs_file):
  43. if os.path.isfile(dictionary_file):
  44. # if dictionary file already exists, we don't need the lapse descriptors
  45. print(f"{dictionary_file} already exists, skipping lapse descriptor extraction...")
  46. else:
  47. print(f"{lapse_dscs_file} already exists, loading lapse descriptors from file... ", end="")
  48. lapse_dscs = np.load(lapse_dscs_file)
  49. assert lapse_dscs.shape[-1] == 128
  50. lapse_dscs = lapse_dscs.reshape(-1, 128)
  51. print(f"Loaded {len(lapse_dscs)} lapse descriptors!")
  52. else:
  53. # Step 1 - extract dense SIFT descriptors
  54. print("Extracting lapse descriptors...")
  55. lapse_dscs = extract_descriptors(list(session.generate_lapse_images()), kp_step=args.step_size, kp_size=args.keypoint_size)
  56. os.makedirs(save_dir, exist_ok=True)
  57. np.save(lapse_dscs_file, lapse_dscs)
  58. # Motion DSIFT descriptors
  59. if args.include_motion:
  60. if os.path.isfile(motion_dscs_file):
  61. if os.path.isfile(dictionary_file):
  62. # if dictionary file already exists, we don't need the descriptors
  63. print(f"{dictionary_file} already exists, skipping motion descriptor extraction...")
  64. else:
  65. print(f"{motion_dscs_file} already exists, loading motion descriptors from file...", end="")
  66. motion_dscs = np.load(motion_dscs_file)
  67. assert motion_dscs.shape[-1] == 128
  68. motion_dscs = motion_dscs.reshape(-1, 128)
  69. print(f"Loaded {len(motion_dscs)} motion descriptors!")
  70. lapse_dscs = np.concatenate([lapse_dscs, motion_dscs])
  71. else:
  72. # Step 1b - extract dense SIFT descriptors from motion images
  73. print("Extracting motion descriptors...")
  74. motion_dscs = extract_descriptors(list(session.generate_motion_images()), kp_step=args.step_size, kp_size=args.keypoint_size)
  75. os.makedirs(save_dir, exist_ok=True)
  76. np.save(motion_dscs_file, motion_dscs)
  77. lapse_dscs = np.concatenate([lapse_dscs, motion_dscs])
  78. # BOW dictionary
  79. if os.path.isfile(dictionary_file):
  80. print(f"{dictionary_file} already exists, loading BOW dictionary from file...")
  81. dictionaries = np.load(dictionary_file)
  82. else:
  83. # Step 2 - create BOW dictionary from Lapse SIFT descriptors
  84. print(f"Creating BOW vocabulary with {args.clusters} clusters from {len(lapse_dscs)} descriptors...")
  85. start_time = timer()
  86. if args.random_prototypes:
  87. dictionaries = np.array([pick_random_descriptors(lapse_dscs, args.clusters) for i in range(args.num_vocabularies)])
  88. else:
  89. dictionaries = np.array([generate_dictionary_from_descriptors(lapse_dscs, args.clusters)])
  90. end_time = timer()
  91. delta_time = timedelta(seconds=end_time-start_time)
  92. print(f"Clustering took {delta_time}.")
  93. np.save(dictionary_file, dictionaries)
  94. # Extract Lapse BOW features using vocabulary (train data)
  95. if os.path.isfile(train_feat_file):
  96. print(f"{train_feat_file} already exists, skipping lapse BOW feature extraction...")
  97. else:
  98. # Step 3 - calculate training data (BOW features of Lapse images)
  99. print(f"Extracting BOW features from Lapse images...")
  100. features = [feat for _, feat in generate_bow_features(list(session.generate_lapse_images()), dictionaries, kp_step=args.step_size, kp_size=args.keypoint_size)]
  101. np.save(train_feat_file, features)
  102. print("Complete!")
  103. if __name__ == "__main__":
  104. main()