1
0

Dataset.py 2.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. import os
  2. from tqdm import tqdm
  3. from py.DatasetStatistics import DatasetStatistics
  4. from py.FileUtils import list_folders, list_jpegs_recursive, expected_subfolders, verify_expected_subfolders
  5. from py.Session import Session
  6. # Represents the whole dataset consisting of multiple sessions. Can be used to get
  7. # session instances or to get an statistics instance.
  8. class Dataset:
  9. def __init__(self, base_path: str):
  10. """Create a new dataset instance.
  11. Args:
  12. base_path (str): Path to dataset, should contain subfolders for sessions.
  13. """
  14. self.base_path = base_path
  15. self.raw_sessions = []
  16. self.__parse_subdirectories()
  17. def __parse_subdirectories(self):
  18. self.raw_sessions = sorted(list_folders(self.base_path))
  19. # Verify every session contains the subfolders Motion, Lapse, Full
  20. for folder in self.raw_sessions:
  21. path = os.path.join(self.base_path, folder)
  22. verify_expected_subfolders(path)
  23. print(f"Found {len(self.raw_sessions)} sessions")
  24. def get_sessions(self) -> list:
  25. """Get names of all sessions (without prefixes).
  26. Returns:
  27. list of str: session names
  28. """
  29. # cut off the first 33 characters (redundant)
  30. return [name[33:] for name in self.raw_sessions]
  31. def create_statistics(self) -> DatasetStatistics:
  32. """Accumulate statistics over the dataset and return a new statistics instance.
  33. Returns:
  34. DatasetStatistics: statistics instance
  35. """
  36. counts = {}
  37. for folder in tqdm(self.raw_sessions):
  38. counts[folder[33:]] = {}
  39. counts[folder[33:]]["Total"] = 0
  40. for subfolder in expected_subfolders:
  41. path = os.path.join(self.base_path, folder, subfolder)
  42. numFiles = len(list_jpegs_recursive(path))
  43. counts[folder[33:]][subfolder] = numFiles
  44. counts[folder[33:]]["Total"] += numFiles
  45. return DatasetStatistics(counts)
  46. def create_session(self, session_name: str) -> Session:
  47. """Return a new session instance from the session name.
  48. Args:
  49. session_name (str): Session name, e.g. beaver_01. Not case-sensitive.
  50. Raises:
  51. ValueError: No or multiple sessions matching session name
  52. Returns:
  53. Session: Session instance
  54. """
  55. if session_name in self.raw_sessions:
  56. return Session(os.path.join(self.base_path, session_name))
  57. filtered = [s for s in self.raw_sessions if session_name.lower() in s.lower()]
  58. if len(filtered) == 0:
  59. raise ValueError(f"There are no sessions matching this name: {filtered}")
  60. elif len(filtered) > 1:
  61. raise ValueError(f"There are several sessions matching this name: {session_name}")
  62. return Session(os.path.join(self.base_path, filtered[0]))