DatasetStatistics.py 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. from turtle import pd
  2. from warnings import warn
  3. import numpy as np
  4. import pandas as pd
  class DatasetStatistics:
      """Per-folder dataset statistics held as a nested dict plus a DataFrame.

      Attributes:
          stats: Dict mapping a row name ("folder") to a dict of column name
              ("subfolder") -> value.
          df: ``pandas.DataFrame`` rebuilt from ``stats`` after every change,
              with one row per folder and one column per subfolder.
      """

      # Immutable defaults; a fresh list is created from them on every call so
      # that no mutable object is shared between calls.
      _DEFAULT_VIEW_COLS = ("Lapse", "Motion", "Full", "Total")
      _DEFAULT_PLOT_COLS = ("Lapse", "Motion", "Full")

      def __init__(self, stats_dict: dict = None, load_from_file: str = None):
          """Create a new statistics instance.

          The statistics can either be defined as a dict (see
          Dataset#create_statistics()) or loaded from a file. If both arguments
          are given, ``stats_dict`` wins and the file is not read.

          Args:
              stats_dict (dict, optional): Dict with statistics. It is stored by
                  reference (not copied), so later calls such as
                  ``add_total_row`` modify the caller's dict. Defaults to None.
              load_from_file (str, optional): File saved with
                  DatasetStatistics#save(). Defaults to None.

          Raises:
              ValueError: If neither stats_dict nor load_from_file is set.
          """
          self.stats = {}
          if stats_dict is not None:
              self.stats = stats_dict
              self.__update_dataframe()
          elif load_from_file is not None:
              # load() rebuilds the DataFrame itself; no second rebuild needed.
              self.load(load_from_file)
          else:
              raise ValueError("Please provide 'stats_dict' or 'load_from_file'.")

      def __update_dataframe(self):
          """Rebuild ``self.df`` from ``self.stats`` (folders become rows)."""
          self.df = pd.DataFrame.from_dict(self.stats).transpose()

      def add_total_row(self, row_name="Z_Total") -> "DatasetStatistics":
          """Append a row holding the per-column sum over all existing rows.

          If ``row_name`` already exists, a warning is issued and nothing is
          changed.

          Args:
              row_name (str, optional): Key of the new row. Defaults to "Z_Total".

          Returns:
              DatasetStatistics: ``self``, to allow call chaining.
          """
          if row_name in self.stats:
              warn(f"{row_name} is already a defined row")
              return self

          totals = {}
          # iterate over all folders and subfolders
          for columns in self.stats.values():
              for subfolder, value in columns.items():
                  if subfolder in totals:
                      # Non-in-place "+" on purpose: "+=" would mutate the first
                      # folder's value when values are lists or arrays.
                      totals[subfolder] = totals[subfolder] + value
                  else:
                      totals[subfolder] = value

          # Inserted after the loop, so the total row is the last key in stats.
          self.stats[row_name] = totals
          self.__update_dataframe()
          return self

      def save(self, filename="dataset_stats.npy"):
          """Write ``self.stats`` to ``filename`` with ``numpy.save``.

          NOTE(review): numpy appends ".npy" when ``filename`` lacks that
          extension, so the file on disk may be named differently from what the
          message below reports.

          Args:
              filename (str, optional): Target path. Defaults to
                  "dataset_stats.npy".
          """
          np.save(filename, self.stats)
          print(f"Saved to {filename}.")

      def load(self, filename="dataset_stats.npy"):
          """Replace ``self.stats`` with the content of ``filename`` and rebuild ``df``.

          Args:
              filename (str, optional): File written by ``save``. Defaults to
                  "dataset_stats.npy".
          """
          # SECURITY: allow_pickle=True unpickles arbitrary objects and can
          # execute code; only load files from trusted sources.
          self.stats = np.load(filename, allow_pickle=True).tolist()
          self.__update_dataframe()
          print(f"Loaded from {filename}.")

      def view(self, col_order=None) -> pd.DataFrame:
          """Return the statistics sorted by row name with the chosen columns.

          Args:
              col_order (optional): Column selection, passed unchanged to
                  DataFrame indexing. Defaults to
                  ["Lapse", "Motion", "Full", "Total"].

          Returns:
              pd.DataFrame: Selection of ``self.df`` with its index sorted.
          """
          if col_order is None:
              col_order = list(self._DEFAULT_VIEW_COLS)
          return self.df.sort_index()[col_order]

      def plot_sessions(self, cols=None, figsize=(20, 10), style=None, exclude_last_row=False):
          """Draw a bar plot of the selected columns, one group of bars per row.

          Args:
              cols (optional): Columns to plot. Defaults to
                  ["Lapse", "Motion", "Full"].
              figsize (tuple, optional): Figure size forwarded to pandas.
                  Defaults to (20, 10).
              style (optional): Forwarded unchanged as the ``style`` argument of
                  ``DataFrame.plot.bar``. Defaults to {"width": 2}.
              exclude_last_row (bool, optional): Drop the last row of ``self.df``
                  in its current (unsorted) order before plotting. That is the
                  total row only if it was the last row added. Defaults to False.

          Returns:
              Whatever ``DataFrame.plot.bar`` returns.
          """
          if cols is None:
              cols = list(self._DEFAULT_PLOT_COLS)
          if style is None:
              style = {"width": 2}

          df = self.df[cols]
          # Plot lapse, motion, full columns without the last row (Z_Total)
          if exclude_last_row:
              df = df.iloc[:-1]
          return df.plot.bar(figsize=figsize, style=style)