train_bow.py

import argparse
import os

import cv2 as cv
import numpy as np
from tqdm import tqdm

from py.Dataset import Dataset
from py.Session import SessionImage

def dense_keypoints(img, step=30, off=(15, 12)):
    """Generate a dense, regular grid of keypoints with the given step size and (y, x) offset."""
    border_dist = (step + 1) // 2
    return [cv.KeyPoint(x, y, step)
            for y in range(border_dist + off[0], img.shape[0] - border_dist, step)
            for x in range(border_dist + off[1], img.shape[1] - border_dist, step)]
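# For example, a 1920x1080 grayscale image (img.shape == (1080, 1920)) with the
# defaults step=30 and off=(15, 12) gives border_dist=15, so the grid starts at
# y=30, x=27 and yields 35 rows x 63 columns = 2205 keypoints of size 30.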

def extract_descriptors(images: list[SessionImage]):
    """Compute dense SIFT (DSIFT) descriptors for every image."""
    sift = cv.SIFT_create()
    dscs = []
    for image in tqdm(images):
        img = image.read_opencv(gray=True)
        kp = dense_keypoints(img)
        kp, des = sift.compute(img, kp)
        dscs.append(des)
    # Assumes all images share the same dimensions; every per-image descriptor
    # array then has the same shape and the result is a regular 3D array.
    return np.array(dscs)

def generate_dictionary(dscs, dictionary_size):
    """Cluster the descriptors with k-means into a BOW vocabulary."""
    # dictionary size = number of clusters
    BOW = cv.BOWKMeansTrainer(dictionary_size)
    for dsc in dscs:
        BOW.add(dsc)
    # cluster() runs k-means over all added descriptors and returns the
    # cluster centers as a (dictionary_size, 128) float32 array.
    dictionary = BOW.cluster()
    return dictionary

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="BOW train script")
    parser.add_argument("dataset_dir", type=str, help="Directory of the dataset containing all session folders")
    parser.add_argument("session_name", type=str, help="Name of the session to use for Lapse images (e.g. marten_01)")
    parser.add_argument("--clusters", type=int, help="Number of clusters / BOW vocabulary size", default=1024)
    args = parser.parse_args()

    ds = Dataset(args.dataset_dir)
    session = ds.create_session(args.session_name)
    save_dir = f"./bow_train_NoBackup/{session.name}"

    # Lapse DSIFT descriptors (cached on disk after the first run)
    lapse_dscs_file = os.path.join(save_dir, "lapse_dscs.npy")
    if os.path.isfile(lapse_dscs_file):
        print(f"{lapse_dscs_file} already exists, loading lapse descriptors from file...")
        lapse_dscs = np.load(lapse_dscs_file)
    else:
        print("Extracting lapse descriptors...")
        lapse_dscs = extract_descriptors(list(session.generate_lapse_images()))
        os.makedirs(save_dir, exist_ok=True)
        np.save(lapse_dscs_file, lapse_dscs)

    # BOW dictionary (cached per cluster count)
    dictionary_file = os.path.join(save_dir, f"bow_dict_{args.clusters}.npy")
    if os.path.isfile(dictionary_file):
        print(f"{dictionary_file} already exists, loading BOW dictionary from file...")
        dictionary = np.load(dictionary_file)
    else:
        print(f"Creating BOW vocabulary with {args.clusters} clusters...")
        dictionary = generate_dictionary(lapse_dscs, args.clusters)
        np.save(dictionary_file, dictionary)

    print("Complete!")
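
Once trained, the saved vocabulary can be applied to new images. Below is a minimal sketch, assuming the dictionary file produced above (session and cluster count taken from the script's example values) and reusing dense_keypoints from this script; cv.BOWImgDescriptorExtractor assigns each dense SIFT descriptor to its nearest visual word and returns a normalized word-frequency histogram. The image path is a placeholder.

import cv2 as cv
import numpy as np

from train_bow import dense_keypoints

# Load the trained vocabulary (path assumes the marten_01 session and 1024 clusters).
dictionary = np.load("./bow_train_NoBackup/marten_01/bow_dict_1024.npy")

sift = cv.SIFT_create()
flann = cv.FlannBasedMatcher()  # KD-tree matcher; works on float32 SIFT descriptors
bow_extractor = cv.BOWImgDescriptorExtractor(sift, flann)
bow_extractor.setVocabulary(dictionary.astype(np.float32))

# Histogram of visual-word frequencies over the dense grid, shape (1, 1024).
img = cv.imread("example_image.jpg", cv.IMREAD_GRAYSCALE)
hist = bow_extractor.compute(img, dense_keypoints(img))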