eval_bow.py

# Copyright (c) 2023 Felix Kleinsteuber and Computer Vision Group, Friedrich Schiller University Jena
# Approach 3: Local features
# This script calculates BOW features of Motion images using a trained
# BOW vocabulary and scores them with one-class SVMs.
# See train_bow.py for training.
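# Pipeline context: train_bow.py produces the BOW dictionary (visual vocabulary)
# and the BOW feature histograms of the training images; this script loads both,
# fits one one-class SVM per dictionary, and writes a per-image score CSV for
# the session's Motion images.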

import argparse
import os

import numpy as np
from sklearn import svm
from tqdm import tqdm

from py.Dataset import Dataset
from py.LocalFeatures import generate_bow_features


def main():
    parser = argparse.ArgumentParser(description="BOW evaluation script")
    parser.add_argument("dataset_dir", type=str, help="Directory of the dataset containing all session folders")
    parser.add_argument("session_name", type=str, help="Name of the session to evaluate (e.g. marten_01)")
    parser.add_argument("--clusters", type=int, help="Number of clusters / BOW vocabulary size", default=1024)
    parser.add_argument("--step_size", type=int, help="DSIFT keypoint step size. Smaller step size = more keypoints.", default=30)
    parser.add_argument("--keypoint_size", type=int, help="DSIFT keypoint size. Defaults to step_size.", default=-1)
    parser.add_argument("--include_motion", action="store_true", help="Include motion images for training.")
    parser.add_argument("--random_prototypes", action="store_true", help="Pick random prototype vectors instead of doing kmeans.")
    args = parser.parse_args()

    if args.keypoint_size <= 0:
        args.keypoint_size = args.step_size
    print(f"Using keypoint size {args.keypoint_size} with step size {args.step_size}.")

    ds = Dataset(args.dataset_dir)
    session = ds.create_session(args.session_name)
    save_dir = f"./bow_train_NoBackup/{session.name}"

    suffix = ""
    if args.include_motion:
        suffix += "_motion"
        print("Including motion data for prototype selection!")
    if args.random_prototypes:
        suffix += "_random"
        print("Picking random prototypes instead of using kmeans!")

    dictionary_file = os.path.join(save_dir, f"bow_dict_{args.step_size}_{args.keypoint_size}_{args.clusters}{suffix}.npy")
    train_feat_file = os.path.join(save_dir, f"bow_train_{args.step_size}_{args.keypoint_size}_{args.clusters}{suffix}.npy")
    eval_file = os.path.join(save_dir, f"bow_eval_{args.step_size}_{args.keypoint_size}_{args.clusters}{suffix}.csv")
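
    # Note: the dictionary and train feature files are expected to have been
    # produced by train_bow.py with the same step_size, keypoint_size, clusters,
    # and flag combination, since those parameters are encoded in the file names.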

    if not os.path.isfile(dictionary_file):
        print(f"ERROR: BOW dictionary missing! ({dictionary_file})")
    elif not os.path.isfile(train_feat_file):
        print(f"ERROR: Train data file missing! ({train_feat_file})")
    elif os.path.isfile(eval_file):
        print(f"ERROR: Eval file already exists! ({eval_file})")
    else:
        print(f"Loading dictionary from {dictionary_file}...")
        dictionaries = np.load(dictionary_file)
        print(f"Shape of dictionaries: {dictionaries.shape}")  # (num_dicts, dict_size, 128)
        assert len(dictionaries.shape) == 3 and dictionaries.shape[2] == 128

        print(f"Loading training data from {train_feat_file}...")
        train_data = np.load(train_feat_file)
        print(f"Shape of training data: {train_data.shape}")  # (num_train_images, num_dicts, 1, dict_size)
        assert len(train_data.shape) == 4
        assert train_data.shape[1] == dictionaries.shape[0]
        assert train_data.shape[2] == 1
        assert train_data.shape[3] == dictionaries.shape[1]
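
        # train_data holds one BOW histogram per dictionary for every training
        # image, so a separate one-class SVM is fitted below on the slice
        # belonging to each dictionary.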
        print(f"Fitting {dictionaries.shape[0]} one-class SVMs...")
        clfs = [svm.OneClassSVM().fit(train_data[:, i, 0, :].squeeze()) for i in tqdm(range(dictionaries.shape[0]))]
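
        # sklearn's OneClassSVM.decision_function returns larger (positive)
        # values for samples that resemble the training data and negative
        # values for likely outliers; one such score per dictionary is written
        # for every Motion image.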

        print("Evaluating...")
        with open(eval_file, "a+") as f:
            for filename, feats in generate_bow_features(list(session.generate_motion_images()), dictionaries, kp_step=args.step_size, kp_size=args.keypoint_size):
                ys = [clf.decision_function(feat)[0] for clf, feat in zip(clfs, feats)]
                ys_out = ",".join([str(y) for y in ys])
                f.write(f"{filename},{ys_out}\n")
                f.flush()
        print("Complete!")


if __name__ == "__main__":
    main()
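
# Example invocation (the dataset path below is illustrative, not from the
# original repository):
#   python eval_bow.py /path/to/dataset marten_01 --clusters 1024 --step_size 30
#
# A minimal sketch for reading the resulting scores back, assuming the CSV
# layout written above (filename, then one decision score per dictionary):
#
#   import csv
#   with open("./bow_train_NoBackup/marten_01/bow_eval_30_30_1024.csv") as f:
#       for row in csv.reader(f):
#           filename, scores = row[0], [float(s) for s in row[1:]]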