DatasetStatistics.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106
  1. # Copyright (c) 2023 Felix Kleinsteuber and Computer Vision Group, Friedrich Schiller University Jena
  2. from turtle import pd
  3. from warnings import warn
  4. import numpy as np
  5. import pandas as pd
  6. # helper for accumulating, saving, loading, and displaying dataset statistics
  7. class DatasetStatistics:
  8. def __init__(self, stats_dict: dict = None, load_from_file: str = None):
  9. """Create a new statistics instance. The statistics can either be defined
  10. as a dict (see Dataset#create_statistics()) or loaded from a file.
  11. Args:
  12. stats_dict (dict, optional): Dict with statistics. Defaults to None.
  13. load_from_file (str, optional): File saved with DatasetStatistics#save(). Defaults to None.
  14. Raises:
  15. ValueError: If neither stats_dict nor load_from_file is set.
  16. """
  17. self.stats = {}
  18. if stats_dict is not None:
  19. self.stats = stats_dict
  20. elif load_from_file is not None:
  21. self.load(load_from_file)
  22. else:
  23. raise ValueError("Please provide 'stats_dict' or 'load_from_file'.")
  24. self.__update_dataframe()
  25. def __update_dataframe(self):
  26. self.df = pd.DataFrame.from_dict(self.stats).transpose()
  27. def add_total_row(self, row_name = "Z_Total") -> "DatasetStatistics":
  28. """Add a row to the pandas dataframe with totals of all columns. Should only be called once.
  29. Args:
  30. row_name (str, optional): Name of the new row. Defaults to "Z_Total".
  31. Returns:
  32. DatasetStatistics: self
  33. """
  34. if row_name in self.stats:
  35. warn(f"{row_name} is already a defined row")
  36. return self
  37. self.stats[row_name] = {}
  38. # iterate over all folders and subfolders
  39. for folder in self.stats:
  40. if folder != row_name:
  41. for subfolder in self.stats[folder]:
  42. # add to total row
  43. if subfolder in self.stats[row_name]:
  44. self.stats[row_name][subfolder] += self.stats[folder][subfolder]
  45. else:
  46. self.stats[row_name][subfolder] = self.stats[folder][subfolder]
  47. self.__update_dataframe()
  48. return self
  49. def save(self, filename = "dataset_stats.npy"):
  50. """Save statistics stored in this instance to a file using numpy.
  51. Args:
  52. filename (str, optional): Target file name. Defaults to "dataset_stats.npy".
  53. """
  54. np.save(filename, self.stats)
  55. print(f"Saved to {filename}.")
  56. def load(self, filename = "dataset_stats.npy"):
  57. """Load statistics from a file using numpy.
  58. Args:
  59. filename (str, optional): Target file name. Defaults to "dataset_stats.npy".
  60. """
  61. self.stats = np.load(filename, allow_pickle=True).tolist()
  62. self.__update_dataframe()
  63. print(f"Loaded from {filename}.")
  64. def view(self, col_order = ["Lapse", "Motion", "Full", "Total"]) -> pd.DataFrame:
  65. """Display the statistics dataframe.
  66. Args:
  67. col_order (list, optional): Order of columns. Defaults to ["Lapse", "Motion", "Full", "Total"].
  68. Returns:
  69. pd.DataFrame: data frame
  70. """
  71. return self.df.sort_index()[col_order]
  72. def plot_sessions(self, cols = ["Lapse", "Motion", "Full"], figsize = (20, 10), style = {"width": 2}, exclude_last_row = False):
  73. """Plot the statistics dataframe as a bar plot.
  74. Args:
  75. cols (list, optional): Columns to include. Defaults to ["Lapse", "Motion", "Full"].
  76. figsize (tuple, optional): Plot size. Defaults to (20, 10).
  77. style (dict, optional): Additional style arguments. Defaults to {"width": 2}.
  78. exclude_last_row (bool, optional): If True, the last row will not be plotted. Defaults to False.
  79. Returns:
  80. _type_: _description_
  81. """
  82. df = self.df[cols]
  83. # Plot lapse, motion, full columns without the last row (Z_Total)
  84. if exclude_last_row:
  85. df = df.iloc[:-1]
  86. return df.plot.bar(figsize=figsize, style=style)