123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566 |
- from turtle import pd
- from warnings import warn
- import numpy as np
- import pandas as pd
class DatasetStatistics:
    """Per-folder dataset statistics, held as a nested dict and a DataFrame.

    ``stats`` maps a row name (folder) to a dict of ``{column: value}``.
    ``df`` is the same data as a DataFrame with one row per folder; it is
    rebuilt whenever ``stats`` changes through this class.
    """

    def __init__(self, stats_dict: dict = None, load_from_file: str = None):
        """Create a new statistics instance. The statistics can either be defined
        as a dict (see Dataset#create_statistics()) or loaded from a file.

        Args:
            stats_dict (dict, optional): Dict with statistics. Defaults to None.
            load_from_file (str, optional): File saved with DatasetStatistics#save(). Defaults to None.

        Raises:
            ValueError: If neither stats_dict nor load_from_file is set.
        """
        self.stats = {}
        if stats_dict is not None:
            # Shallow copy: add_total_row() inserts a new key, which must not
            # leak into the dict object owned by the caller.
            self.stats = dict(stats_dict)
            self.__update_dataframe()
        elif load_from_file is not None:
            # load() already rebuilds the DataFrame, so no second rebuild here.
            self.load(load_from_file)
        else:
            raise ValueError("Please provide 'stats_dict' or 'load_from_file'.")

    def __update_dataframe(self):
        """Rebuild ``self.df`` from ``self.stats`` (one row per top-level key)."""
        self.df = pd.DataFrame.from_dict(self.stats).transpose()

    def add_total_row(self, row_name="Z_Total") -> "DatasetStatistics":
        """Append a row holding the per-column sum over all existing rows.

        If ``row_name`` already exists, a warning is issued and nothing is
        changed.

        Args:
            row_name (str, optional): Key of the new total row. Defaults to "Z_Total".

        Returns:
            DatasetStatistics: This instance, to allow call chaining.
        """
        if row_name in self.stats:
            warn(f"{row_name} is already a defined row")
            return self

        totals = {}
        for columns in self.stats.values():
            for column, value in columns.items():
                if column in totals:
                    # Deliberately `a = a + b` and not `a += b`: the first value
                    # seen for a column is stored as-is, so an in-place add on
                    # a mutable value (list, ndarray) would modify that
                    # folder's own entry.
                    totals[column] = totals[column] + value
                else:
                    totals[column] = value

        self.stats[row_name] = totals
        self.__update_dataframe()
        return self

    def save(self, filename="dataset_stats.npy"):
        """Save the statistics dict with ``numpy.save`` and print the target.

        Args:
            filename (str, optional): Target file. Defaults to "dataset_stats.npy".
        """
        np.save(filename, self.stats)
        # numpy.save appends ".npy" to path names lacking it; report the name
        # that was actually written so it can be passed to load() unchanged.
        written_to = filename
        if isinstance(filename, str) and not filename.endswith(".npy"):
            written_to = f"{filename}.npy"
        print(f"Saved to {written_to}.")

    def load(self, filename="dataset_stats.npy"):
        """Replace the statistics with the contents of a file written by save().

        Args:
            filename (str, optional): File to read. Defaults to "dataset_stats.npy".
        """
        # SECURITY: allow_pickle=True unpickles the file content, which can
        # execute arbitrary code. Only load files from trusted sources.
        # tolist() on the loaded array yields the stored Python object back.
        self.stats = np.load(filename, allow_pickle=True).tolist()
        self.__update_dataframe()
        print(f"Loaded from {filename}.")

    def view(self, col_order=None) -> pd.DataFrame:
        """Return the statistics sorted by row name with the given columns.

        Args:
            col_order (list, optional): Columns to select, in display order.
                Defaults to ["Lapse", "Motion", "Full", "Total"].

        Returns:
            pd.DataFrame: Selection of ``self.df`` sorted by index.
        """
        if col_order is None:
            col_order = ["Lapse", "Motion", "Full", "Total"]
        return self.df.sort_index()[col_order]

    def plot_sessions(self, cols=None, figsize=(20, 10), style={"width": 2}, exclude_last_row=False):
        """Draw a bar plot of the selected columns, one group per row.

        Args:
            cols (list, optional): Columns to plot.
                Defaults to ["Lapse", "Motion", "Full"].
            figsize (tuple, optional): Forwarded to the plot call. Defaults to (20, 10).
            style (dict, optional): Forwarded to the plot call as ``style``.
                Defaults to {"width": 2}.
            exclude_last_row (bool, optional): Drop the last row of ``self.df``
                before plotting (the total row, if add_total_row() was the
                last modification). Defaults to False.

        Returns:
            Whatever ``DataFrame.plot.bar`` returns.
        """
        # NOTE: the `style` default is a dict literal but is never mutated
        # here; it is kept in the signature so that an explicit style=None
        # is still forwarded unchanged.
        if cols is None:
            cols = ["Lapse", "Motion", "Full"]
        df = self.df[cols]
        if exclude_last_row:
            df = df.iloc[:-1]
        return df.plot.bar(figsize=figsize, style=style)
|