소스 검색

added some error handling and start date filtering

Dimitri Korsch 6 년 전
부모
커밋
3c257c987c
2개의 변경된 파일41개의 추가작업 그리고 7개의 파일을 삭제
  1. 1 1
      cvfinetune/__init__.py
  2. 40 6
      cvfinetune/utils/sacred_plotter.py

+ 1 - 1
cvfinetune/__init__.py

@@ -1 +1 @@
-__version__ = "0.2.6"
+__version__ = "0.3.0"

+ 40 - 6
cvfinetune/utils/sacred_plotter.py

@@ -1,9 +1,28 @@
 import pymongo as pym
 import numpy as np
 import matplotlib.pyplot as plt
+import warnings
+from datetime import datetime
 
 from matplotlib.gridspec import GridSpec
 
+def parse_date(date, fmt="{:d}-{:02d}-{:02d} 00:00:00"):
+	if isinstance(date, str):
+		return datetime.fromisoformat(date)
+
+	elif isinstance(date, (tuple, list)):
+		return datetime.fromisoformat(fmt.format(*date))
+
+	elif date is None:
+		return None
+
+	else:
+		warnings.warn("Unsupported date format: {date}. Date parsing skipped".format(date))
+
+def query_to_str(query):
+	query_list = [f"{key}={value}" for key, value in query.items()]
+	return "{{{}}}".format(", ".join(query_list))
+
 class SacredPlotter(object):
 
 	@staticmethod
@@ -44,6 +63,7 @@ class SacredPlotter(object):
 		query_factory,
 		setup_to_label,
 		include_running=False,
+		start_time=None,
 		metrics_key="val/main/",
 		**plot_kwargs):
 
@@ -63,6 +83,9 @@ class SacredPlotter(object):
 				- include_running: whether running experiments should
 					be included or not
 
+				- start_time: time (starting of an experiment) from which
+					the experiments should be considered
+
 				- metrics_key: prefix that will be appended to each
 					metric name
 
@@ -84,17 +107,28 @@ class SacredPlotter(object):
 				if not include_running:
 					query["status"] = {"$ne": "RUNNING"}
 
-				res.append((setup, self.get_values(f"{metrics_key}{metric}", query)))
+				if start_time is not None:
+					query["start_time"] = {"$gte": parse_date(start_time)}
+
+				values = self.get_values(f"{metrics_key}{metric}", query)
+				if len(values) == 0:
+					warnings.warn(f"No values found for query { query_to_str(query) }")
+				else:
+					res.append((setup, values))
 
 			row, col = np.unravel_index(i, (n_rows, n_cols))
 
+			if len(res) == 0:
+				warnings.warn(f"No setups for metric \"{metric}\" collected!")
+
+			else:
 
-			ax = plt.subplot(grid[row, col])
-			labels, values = zip(*[(setup_to_label(setup, vals), vals) for setup, vals in res if vals])
-			ax.boxplot(values, labels=labels, **plot_kwargs)
-			ax.set_title(f"Metric: {metric}")
+				ax = plt.subplot(grid[row, col])
+				labels, values = zip(*[(setup_to_label(setup, vals), vals) for setup, vals in res if vals])
+				ax.boxplot(values, labels=labels, **plot_kwargs)
+				ax.set_title(f"Metric: {metric}")
 
-			axs.append(ax)
+				axs.append(ax)
 
 		return axs