DatasetStatistics.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  1. from turtle import pd
  2. from warnings import warn
  3. import numpy as np
  4. import pandas as pd
  5. # helper for accumulating, saving, loading, and displaying dataset statistics
  6. class DatasetStatistics:
  7. def __init__(self, stats_dict: dict = None, load_from_file: str = None):
  8. """Create a new statistics instance. The statistics can either be defined
  9. as a dict (see Dataset#create_statistics()) or loaded from a file.
  10. Args:
  11. stats_dict (dict, optional): Dict with statistics. Defaults to None.
  12. load_from_file (str, optional): File saved with DatasetStatistics#save(). Defaults to None.
  13. Raises:
  14. ValueError: If neither stats_dict nor load_from_file is set.
  15. """
  16. self.stats = {}
  17. if stats_dict is not None:
  18. self.stats = stats_dict
  19. elif load_from_file is not None:
  20. self.load(load_from_file)
  21. else:
  22. raise ValueError("Please provide 'stats_dict' or 'load_from_file'.")
  23. self.__update_dataframe()
  24. def __update_dataframe(self):
  25. self.df = pd.DataFrame.from_dict(self.stats).transpose()
  26. def add_total_row(self, row_name = "Z_Total") -> "DatasetStatistics":
  27. """Add a row to the pandas dataframe with totals of all columns. Should only be called once.
  28. Args:
  29. row_name (str, optional): Name of the new row. Defaults to "Z_Total".
  30. Returns:
  31. DatasetStatistics: self
  32. """
  33. if row_name in self.stats:
  34. warn(f"{row_name} is already a defined row")
  35. return self
  36. self.stats[row_name] = {}
  37. # iterate over all folders and subfolders
  38. for folder in self.stats:
  39. if folder != row_name:
  40. for subfolder in self.stats[folder]:
  41. # add to total row
  42. if subfolder in self.stats[row_name]:
  43. self.stats[row_name][subfolder] += self.stats[folder][subfolder]
  44. else:
  45. self.stats[row_name][subfolder] = self.stats[folder][subfolder]
  46. self.__update_dataframe()
  47. return self
  48. def save(self, filename = "dataset_stats.npy"):
  49. """Save statistics stored in this instance to a file using numpy.
  50. Args:
  51. filename (str, optional): Target file name. Defaults to "dataset_stats.npy".
  52. """
  53. np.save(filename, self.stats)
  54. print(f"Saved to {filename}.")
  55. def load(self, filename = "dataset_stats.npy"):
  56. """Load statistics from a file using numpy.
  57. Args:
  58. filename (str, optional): Target file name. Defaults to "dataset_stats.npy".
  59. """
  60. self.stats = np.load(filename, allow_pickle=True).tolist()
  61. self.__update_dataframe()
  62. print(f"Loaded from {filename}.")
  63. def view(self, col_order = ["Lapse", "Motion", "Full", "Total"]) -> pd.DataFrame:
  64. """Display the statistics dataframe.
  65. Args:
  66. col_order (list, optional): Order of columns. Defaults to ["Lapse", "Motion", "Full", "Total"].
  67. Returns:
  68. pd.DataFrame: data frame
  69. """
  70. return self.df.sort_index()[col_order]
  71. def plot_sessions(self, cols = ["Lapse", "Motion", "Full"], figsize = (20, 10), style = {"width": 2}, exclude_last_row = False):
  72. """Plot the statistics dataframe as a bar plot.
  73. Args:
  74. cols (list, optional): Columns to include. Defaults to ["Lapse", "Motion", "Full"].
  75. figsize (tuple, optional): Plot size. Defaults to (20, 10).
  76. style (dict, optional): Additional style arguments. Defaults to {"width": 2}.
  77. exclude_last_row (bool, optional): If True, the last row will not be plotted. Defaults to False.
  78. Returns:
  79. _type_: _description_
  80. """
  81. df = self.df[cols]
  82. # Plot lapse, motion, full columns without the last row (Z_Total)
  83. if exclude_last_row:
  84. df = df.iloc[:-1]
  85. return df.plot.bar(figsize=figsize, style=style)