# eval_bow.py
# Score a session's motion images with a one-class SVM fitted on
# bag-of-words (BOW) training features.
import argparse
import os
import sys

import numpy as np
from sklearn import svm

from py.Dataset import Dataset
from py.LocalFeatures import generate_bow_features
  7. def main():
  8. parser = argparse.ArgumentParser(description="BOW train script")
  9. parser.add_argument("dataset_dir", type=str, help="Directory of the dataset containing all session folders")
  10. parser.add_argument("session_name", type=str, help="Name of the session to use for Lapse images (e.g. marten_01)")
  11. parser.add_argument("--clusters", type=int, help="Number of clusters / BOW vocabulary size", default=1024)
  12. parser.add_argument("--step_size", type=int, help="DSIFT keypoint step size. Smaller step size = more keypoints.", default=30)
  13. args = parser.parse_args()
  14. ds = Dataset(args.dataset_dir)
  15. session = ds.create_session(args.session_name)
  16. save_dir = f"./bow_train_NoBackup/{session.name}"
  17. # Lapse DSIFT descriptors
  18. dictionary_file = os.path.join(save_dir, f"bow_dict_{args.step_size}_{args.clusters}.npy")
  19. train_feat_file = os.path.join(save_dir, f"bow_train_{args.step_size}_{args.clusters}.npy")
  20. eval_file = os.path.join(save_dir, f"bow_eval_{args.step_size}_{args.clusters}.csv")
  21. if not os.path.isfile(dictionary_file):
  22. print(f"ERROR: BOW dictionary missing! ({dictionary_file})")
  23. elif not os.path.isfile(train_feat_file):
  24. print(f"ERROR: Train data file missing! ({train_feat_file})")
  25. elif os.path.isfile(eval_file):
  26. print(f"ERROR: Eval file already exists! ({eval_file})")
  27. else:
  28. print(f"Loading dictionary from {dictionary_file}...")
  29. dictionary = np.load(dictionary_file)
  30. print(f"Loading training data from {train_feat_file}...")
  31. train_data = np.load(train_feat_file).squeeze()
  32. print(f"Fitting one-class SVM...")
  33. clf = svm.OneClassSVM().fit(train_data)
  34. print("Evaluating...")
  35. with open(eval_file, "a+") as f:
  36. for filename, feat in generate_bow_features(list(session.generate_motion_images()), dictionary, kp_step=args.step_size):
  37. y = clf.decision_function(feat)[0]
  38. f.write(f"{filename},{y}\n")
  39. f.flush()
  40. print("Complete!")
  41. if __name__ == "__main__":
  42. main()