train_bow.py

# Approach 3: Local features
# This script generates a BOW vocabulary from densely sampled SIFT features
# on Lapse images.
# See eval_bow.py for evaluation.
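# Example invocation (dataset path is a placeholder; the session name is the
# example from the argparse help below):
#   python train_bow.py /path/to/dataset marten_01 --clusters 1024 --step_size 30 --keypoint_size 60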
import argparse
import os

import numpy as np

from py.Dataset import Dataset
from py.LocalFeatures import extract_descriptors, generate_dictionary_from_descriptors, generate_bow_features


def main():
    parser = argparse.ArgumentParser(description="BOW train script")
    parser.add_argument("dataset_dir", type=str, help="Directory of the dataset containing all session folders")
    parser.add_argument("session_name", type=str, help="Name of the session to use for Lapse images (e.g. marten_01)")
    parser.add_argument("--clusters", type=int, help="Number of clusters / BOW vocabulary size", default=1024)
    parser.add_argument("--step_size", type=int, help="DSIFT keypoint step size. Smaller step size = more keypoints.", default=30)
    parser.add_argument("--keypoint_size", type=int, help="DSIFT keypoint size. Should be >= step_size.", default=60)
    args = parser.parse_args()
    ds = Dataset(args.dataset_dir)
    session = ds.create_session(args.session_name)
    save_dir = f"./bow_train_NoBackup/{session.name}"
    # Lapse DSIFT descriptors
    lapse_dscs_file = os.path.join(save_dir, f"lapse_dscs_{args.step_size}_{args.keypoint_size}.npy")
    dictionary_file = os.path.join(save_dir, f"bow_dict_{args.step_size}_{args.keypoint_size}_{args.clusters}.npy")
    train_feat_file = os.path.join(save_dir, f"bow_train_{args.step_size}_{args.keypoint_size}_{args.clusters}.npy")
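    # All three artifacts are cached in save_dir and keyed by the DSIFT/cluster
    # parameters, so re-running with the same settings skips the expensive steps.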
    if os.path.isfile(lapse_dscs_file):
        if os.path.isfile(dictionary_file):
            # if the dictionary file already exists, we don't need the lapse descriptors
            print(f"{lapse_dscs_file} already exists, skipping lapse descriptor extraction...")
        else:
            print(f"{lapse_dscs_file} already exists, loading lapse descriptors from file...")
            lapse_dscs = np.load(lapse_dscs_file)
    else:
        # Step 1 - extract dense SIFT descriptors
        print("Extracting lapse descriptors...")
        lapse_dscs = extract_descriptors(list(session.generate_lapse_images()), kp_step=args.step_size, kp_size=args.keypoint_size)
        os.makedirs(save_dir, exist_ok=True)
        np.save(lapse_dscs_file, lapse_dscs)
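    # Note: lapse_dscs is intentionally left undefined when the dictionary file
    # already exists; the descriptors are only needed to build the dictionary below.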
  38. # BOW dictionary
  39. if os.path.isfile(dictionary_file):
  40. print(f"{dictionary_file} already exists, loading BOW dictionary from file...")
  41. dictionary = np.load(dictionary_file)
  42. else:
  43. # Step 2 - create BOW dictionary from Lapse SIFT descriptors
  44. print(f"Creating BOW vocabulary with {args.clusters} clusters...")
  45. dictionary = generate_dictionary_from_descriptors(lapse_dscs, args.clusters)
  46. np.save(dictionary_file, dictionary)
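    # dictionary now holds the visual vocabulary (presumably the cluster centers
    # produced by generate_dictionary_from_descriptors in py/LocalFeatures.py).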
    # Extract Lapse BOW features using vocabulary (train data)
    if os.path.isfile(train_feat_file):
        print(f"{train_feat_file} already exists, skipping lapse BOW feature extraction...")
    else:
        # Step 3 - calculate training data (BOW features of Lapse images)
        print("Extracting BOW features from Lapse images...")
        features = [feat for _, feat in generate_bow_features(list(session.generate_lapse_images()), dictionary, kp_step=args.step_size, kp_size=args.keypoint_size)]
        np.save(train_feat_file, features)
    print("Complete!")


if __name__ == "__main__":
    main()
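
# Example: reloading the cached artifacts from another script (a minimal sketch,
# assuming the default parameters and the example session name "marten_01";
# the file names follow the patterns built in main() above):
#
#   import numpy as np
#   dictionary = np.load("./bow_train_NoBackup/marten_01/bow_dict_30_60_1024.npy")
#   train_feats = np.load("./bow_train_NoBackup/marten_01/bow_train_30_60_1024.npy")
#   print(dictionary.shape, train_feats.shape)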