Browse Source

fixed an import. added sacred plotter class

Dimitri Korsch 6 years ago
parent
commit
c7db8810ec

+ 1 - 1
cvfinetune/training/trainer/alpha_pooling.py

@@ -1,4 +1,4 @@
-from .sacred import SacredReport
+from .sacred import SacredTrainer
 from .base import _is_adam, default_intervals
 
 from chainer.training import extensions

+ 0 - 0
cvfinetune/utils/__init__.py


+ 101 - 0
cvfinetune/utils/sacred_plotter.py

@@ -0,0 +1,101 @@
+import pymongo as pym
+import numpy as np
+import matplotlib.pyplot as plt
+
+from matplotlib.gridspec import GridSpec
+
+class SacredPlotter(object):
+
+	@staticmethod
+	def auth_url(creds, host="localhost"):
+		url = "mongodb://{user}:{password}@{host}:27017/{db_name}?authSource=admin".format(
+			host=host, **creds)
+		return url
+
+	def __init__(self, creds):
+		super(SacredPlotter, self).__init__()
+		self.client = pym.MongoClient(SacredPlotter.auth_url(creds))
+
+		self.db = self.client[creds["db_name"]]
+		self.metrics = self.db["metrics"]
+		self.runs = self.db["runs"]
+
+	def get_values(self, metric, query):
+
+		values = []
+		runs = list(self.runs.find(query))
+
+		for run in runs:
+
+			accuracies = self.metrics.find_one(dict(
+				name=metric,
+				run_id=run["_id"],
+			))
+			if accuracies is None or accuracies["values"] is None:
+				continue
+
+			values.append(accuracies["values"][-1])
+
+		return values
+
+	def plot(self,
+		metrics,
+		setups,
+		query_factory,
+		setup_to_label,
+		include_running=False,
+		metrics_key="val/main/",
+		**plot_kwargs):
+
+		"""
+			Arguments:
+				- metrics: defines which metrics to plot
+
+				- setups: defines different setups, that
+					will be compared
+
+				- query_factory: callable; creates from the setup
+					specification the pymongo query
+
+				- setup_to_label: callable; converts a setup
+					specification into a readable label
+
+				- include_running: whether running experiments should
+					be included or not
+
+				- metrics_key: prefix that will be appended to each
+					metric name
+
+		"""
+
+		n_metrics = len(metrics)
+		n_cols = int(np.ceil(np.sqrt(n_metrics)))
+		n_rows = int(np.ceil(n_metrics / n_cols))
+
+		grid = GridSpec(n_rows, n_cols)
+		axs = []
+		for i, metric in enumerate(metrics):
+
+			res = []
+
+			for setup in setups:
+				query = query_factory(setup)
+
+				if not include_running:
+					query["status"] = {"$ne": "RUNNING"}
+
+				res.append((setup, self.get_values(f"{metrics_key}{metric}", query)))
+
+			row, col = np.unravel_index(i, (n_rows, n_cols))
+
+
+			ax = plt.subplot(grid[row, col])
+			labels, values = zip(*[(setup_to_label(setup, vals), vals) for setup, vals in res if vals])
+			ax.boxplot(values, labels=labels, **plot_kwargs)
+			ax.set_title(f"Metric: {metric}")
+
+			axs.append(ax)
+
+		return axs
+
+