Dataset.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. # Copyright (c) 2023 Felix Kleinsteuber and Computer Vision Group, Friedrich Schiller University Jena
  2. import os
  3. from tqdm import tqdm
  4. from py.DatasetStatistics import DatasetStatistics
  5. from py.FileUtils import list_folders, list_jpegs_recursive, expected_subfolders, verify_expected_subfolders
  6. from py.Session import Session
  7. # Represents the whole dataset consisting of multiple sessions. Can be used to get
  8. # session instances or to get a statistics instance.
  9. class Dataset:
  10. def __init__(self, base_path: str):
  11. """Create a new dataset instance.
  12. Args:
  13. base_path (str): Path to dataset, should contain subfolders for sessions.
  14. """
  15. self.base_path = base_path
  16. self.raw_sessions = []
  17. self.__parse_subdirectories()
  18. def __parse_subdirectories(self):
  19. self.raw_sessions = sorted(list_folders(self.base_path))
  20. # Verify every session contains the subfolders Motion, Lapse, Full
  21. for folder in self.raw_sessions:
  22. path = os.path.join(self.base_path, folder)
  23. verify_expected_subfolders(path)
  24. print(f"Found {len(self.raw_sessions)} sessions")
  25. def get_sessions(self) -> list:
  26. """Get names of all sessions (without prefixes).
  27. Returns:
  28. list of str: session names
  29. """
  30. # cut off the first 33 characters (redundant)
  31. return [name[33:] for name in self.raw_sessions]
  32. def create_statistics(self) -> DatasetStatistics:
  33. """Accumulate statistics over the dataset and return a new statistics instance.
  34. Returns:
  35. DatasetStatistics: statistics instance
  36. """
  37. counts = {}
  38. for folder in tqdm(self.raw_sessions):
  39. counts[folder[33:]] = {}
  40. counts[folder[33:]]["Total"] = 0
  41. for subfolder in expected_subfolders:
  42. path = os.path.join(self.base_path, folder, subfolder)
  43. numFiles = len(list_jpegs_recursive(path))
  44. counts[folder[33:]][subfolder] = numFiles
  45. counts[folder[33:]]["Total"] += numFiles
  46. return DatasetStatistics(counts)
  47. def create_session(self, session_name: str) -> Session:
  48. """Return a new session instance from the session name.
  49. Args:
  50. session_name (str): Session name, e.g. beaver_01. Not case-sensitive.
  51. Raises:
  52. ValueError: No or multiple sessions matching session name
  53. Returns:
  54. Session: Session instance
  55. """
  56. if session_name in self.raw_sessions:
  57. return Session(os.path.join(self.base_path, session_name))
  58. filtered = [s for s in self.raw_sessions if session_name.lower() in s.lower()]
  59. if len(filtered) == 0:
  60. raise ValueError(f"There are no sessions matching this name: {filtered}")
  61. elif len(filtered) > 1:
  62. raise ValueError(f"There are several sessions matching this name: {session_name}")
  63. return Session(os.path.join(self.base_path, filtered[0]))