Dataset.py 2.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849
  1. import os
  2. from tqdm import tqdm
  3. from py.DatasetStatistics import DatasetStatistics
  4. from py.FileUtils import list_folders, list_jpegs_recursive, expected_subfolders, verify_expected_subfolders
  5. from py.Session import Session
  6. class Dataset:
  7. def __init__(self, base_path: str):
  8. self.base_path = base_path
  9. self.raw_sessions = []
  10. self.__parse_subdirectories()
  11. def __parse_subdirectories(self):
  12. self.raw_sessions = sorted(list_folders(self.base_path))
  13. # Verify every session contains the subfolders Motion, Lapse, Full
  14. for folder in self.raw_sessions:
  15. path = os.path.join(self.base_path, folder)
  16. verify_expected_subfolders(path)
  17. print(f"Found {len(self.raw_sessions)} sessions")
  18. def get_sessions(self) -> list:
  19. # cut off the first 33 characters (redundant)
  20. return [name[33:] for name in self.raw_sessions]
  21. def create_statistics(self) -> DatasetStatistics:
  22. counts = {}
  23. for folder in tqdm(self.raw_sessions):
  24. counts[folder[33:]] = {}
  25. counts[folder[33:]]["Total"] = 0
  26. for subfolder in expected_subfolders:
  27. path = os.path.join(self.base_path, folder, subfolder)
  28. numFiles = len(list_jpegs_recursive(path))
  29. counts[folder[33:]][subfolder] = numFiles
  30. counts[folder[33:]]["Total"] += numFiles
  31. return DatasetStatistics(counts)
  32. def create_session(self, session_name: str) -> Session:
  33. if session_name in self.raw_sessions:
  34. return Session(os.path.join(self.base_path, session_name))
  35. filtered = [s for s in self.raw_sessions if session_name.lower() in s.lower()]
  36. if len(filtered) == 0:
  37. raise ValueError(f"There are no sessions matching this name: {filtered}")
  38. elif len(filtered) > 1:
  39. raise ValueError(f"There are several sessions matching this name: {session_name}")
  40. return Session(os.path.join(self.base_path, filtered[0]))