Skip to content

Instantly share code, notes, and snippets.

@DiKorsch
Last active September 27, 2019 08:41
Show Gist options
  • Save DiKorsch/7cf60a7101e17e138bcddcb0fea6b06a to your computer and use it in GitHub Desktop.
Save DiKorsch/7cf60a7101e17e138bcddcb0fea6b06a to your computer and use it in GitHub Desktop.
Plotting of experiment results stored by Sacred in MongoDB
# Example usage: compare the final validation metrics of two CNN types.
# Expects `SacredPlotter` and `matplotlib.pyplot as plt` to be in scope.
creds = dict(user="username", password="very_secure", db_name="sacred_experiments")
plotter = SacredPlotter(creds)

fig = plt.figure()
plotter.plot(
    metrics=["accuracy", "loss"],
    setups=[
        # each setup is a value of the `cnn_type` config entry
        "gmm",
        "fve",
    ],
    # translate a setup into the pymongo query selecting its runs
    query_factory=lambda setup: {
        "experiment.name": "specific experiment name",
        "config.cnn_type": setup,
    },
    # label = setup name plus the number of runs that contributed a value
    # (the lambda parameter is `setup`; `cnn_type` is not defined here)
    setup_to_label=lambda setup, values: f"{setup}\n{len(values)}",
    include_running=False,
    metrics_key="validation/",
    # remaining keyword arguments are forwarded to `Axes.boxplot`
    showfliers=True,
)
plt.show()
plt.close()
import pymongo as pym
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
class SacredPlotter:
    """Box-plot the final metric values of experiment runs stored in MongoDB.

    The plotter reads from the ``runs`` and ``metrics`` collections of the
    database named by ``creds["db_name"]``.
    """

    @staticmethod
    def auth_url(creds, host="localhost"):
        """Build the MongoDB connection URL for the given credentials.

        Arguments:
            - creds: mapping providing the keys ``user``, ``password``
                and ``db_name``
            - host: host name of the MongoDB server; the port is fixed
                to 27017

        Returns:
            The connection URL as a string (``authSource`` is ``admin``).
        """
        # NOTE: the credentials are interpolated verbatim. If they contain
        # URI-reserved characters (e.g. "@", ":" or "/"), the caller has to
        # percent-escape them beforehand.
        return "mongodb://{user}:{password}@{host}:27017/{db_name}?authSource=admin".format(
            host=host, **creds)

    def __init__(self, creds):
        """Connect to the database described by ``creds`` (see ``auth_url``)."""
        super().__init__()
        self.client = pym.MongoClient(SacredPlotter.auth_url(creds))
        self.db = self.client[creds["db_name"]]
        self.metrics = self.db["metrics"]
        self.runs = self.db["runs"]

    def get_values(self, metric, query):
        """Return the last recorded value of ``metric`` for each matching run.

        Arguments:
            - metric: full name of the metric, as stored in the ``metrics``
                collection
            - query: pymongo query selecting documents of the ``runs``
                collection

        Returns:
            A list with one entry per run that recorded at least one value
            for the metric. Runs without such a record are skipped.
        """
        values = []
        for run in self.runs.find(query):
            record = self.metrics.find_one({"name": metric, "run_id": run["_id"]})
            # Skip runs that have no record for this metric or whose record
            # holds no values yet (missing, None or empty list).
            if not record or not record.get("values"):
                continue
            values.append(record["values"][-1])
        return values

    def _collect_results(self, metric_name, setups, query_factory, include_running):
        """Return a ``(setup, values)`` pair for every setup, in order."""
        results = []
        for setup in setups:
            # Work on a copy, so the dict handed out by the factory is
            # never modified.
            query = dict(query_factory(setup))
            if not include_running:
                # Filter out runs that are still in progress. A status
                # filter set explicitly by the factory takes precedence.
                query.setdefault("status", {"$ne": "RUNNING"})
            results.append((setup, self.get_values(metric_name, query)))
        return results

    def plot(self,
             metrics,
             setups,
             query_factory,
             setup_to_label,
             include_running=False,
             metrics_key="val/main/",
             **plot_kwargs):
        """Draw one box-plot per metric into the current matplotlib figure.

        Arguments:
            - metrics: defines which metrics to plot
            - setups: defines different setups, that
                will be compared
            - query_factory: callable; creates from the setup
                specification the pymongo query
            - setup_to_label: callable; converts a setup
                specification and its list of values into
                a readable label
            - include_running: whether running experiments should
                be included or not
            - metrics_key: prefix that will be prepended to each
                metric name
            - plot_kwargs: forwarded to ``Axes.boxplot``
        """
        n_metrics = len(metrics)
        if n_metrics == 0:
            # Nothing to draw; also avoids a division by zero below.
            return

        # Arrange the subplots in a grid that is as square as possible.
        n_cols = int(np.ceil(np.sqrt(n_metrics)))
        n_rows = int(np.ceil(n_metrics / n_cols))
        grid = GridSpec(n_rows, n_cols)

        for i, metric in enumerate(metrics):
            results = self._collect_results(
                f"{metrics_key}{metric}", setups, query_factory, include_running)

            # Fill the grid row by row.
            row, col = divmod(i, n_cols)
            ax = plt.subplot(grid[row, col])
            ax.set_title(f"Metric: {metric}")

            # Setups without any value would yield an empty box; drop them.
            populated = [
                (setup_to_label(setup, vals), vals)
                for setup, vals in results if vals
            ]
            if not populated:
                # No setup has data for this metric: leave the axes empty
                # instead of failing on unpacking an empty sequence.
                continue

            labels, values = zip(*populated)
            # NOTE: newer Matplotlib releases rename ``labels`` to
            # ``tick_labels``; ``labels`` is kept for older versions.
            ax.boxplot(values, labels=labels, **plot_kwargs)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment