eval_bow.py 2.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152
  1. import argparse
  2. import os
  3. import numpy as np
  4. from sklearn import svm
  5. from py.Dataset import Dataset
  6. from py.LocalFeatures import generate_bow_features
  7. def main():
  8. parser = argparse.ArgumentParser(description="BOW train script")
  9. parser.add_argument("dataset_dir", type=str, help="Directory of the dataset containing all session folders")
  10. parser.add_argument("session_name", type=str, help="Name of the session to use for Lapse images (e.g. marten_01)")
  11. parser.add_argument("--clusters", type=int, help="Number of clusters / BOW vocabulary size", default=1024)
  12. args = parser.parse_args()
  13. ds = Dataset(args.dataset_dir)
  14. session = ds.create_session(args.session_name)
  15. save_dir = f"./bow_train_NoBackup/{session.name}"
  16. # Lapse DSIFT descriptors
  17. dictionary_file = os.path.join(save_dir, f"bow_dict_{args.clusters}.npy")
  18. train_feat_file = os.path.join(save_dir, f"bow_train_{args.clusters}.npy")
  19. eval_file = os.path.join(save_dir, f"bow_eval_{args.clusters}.csv")
  20. if not os.path.isfile(dictionary_file):
  21. print(f"ERROR: BOW dictionary missing! ({dictionary_file})")
  22. elif not os.path.isfile(train_feat_file):
  23. print(f"ERROR: Train data file missing! ({train_feat_file})")
  24. elif os.path.isfile(eval_file):
  25. print(f"ERROR: Eval file already exists! ({eval_file})")
  26. else:
  27. print(f"Loading dictionary from {dictionary_file}...")
  28. dictionary = np.load(dictionary_file)
  29. print(f"Loading training data from {train_feat_file}...")
  30. train_data = np.load(train_feat_file).squeeze()
  31. print(f"Fitting one-class SVM...")
  32. clf = svm.OneClassSVM().fit(train_data)
  33. print("Evaluating...")
  34. with open(eval_file, "a+") as f:
  35. for filename, feat in generate_bow_features(list(session.generate_motion_images()), dictionary):
  36. y = clf.decision_function(feat)[0]
  37. f.write(f"{filename},{y}\n")
  38. f.flush()
  39. print("Complete!")
  40. if __name__ == "__main__":
  41. main()