Skip to content

Instantly share code, notes, and snippets.

@alchzh
Last active May 25, 2023 13:36
Show Gist options
  • Save alchzh/7f5fb3d01e53b2ad737db486d065d1a0 to your computer and use it in GitHub Desktop.
Save alchzh/7f5fb3d01e53b2ad737db486d065d1a0 to your computer and use it in GitHub Desktop.
Plotting multiple titles on an axis via monkey-patching (for corner.py)
from types import MethodType

import corner
import matplotlib.lines as mlines
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import matplotlib.transforms as mtransforms
import numpy as np
# Matplotlib Text artists don't support multiple colors, so we need some
# hacks to place several Text artists neatly next to each other.
# Ideally this would be implemented as a new Artist class that holds
# the individual Text objects.
def neighbor_offset_points(direction, halign, width, height, npad=0):
    """
    Return the (x, y) offset, in points, from an existing Text artist's
    anchor to the anchor of the next neighbor placed in ``direction``.

    Parameters
    ----------
    direction : {"up", "down", "left", "right"}
        Where the neighbor goes. "top" and "bottom" are accepted as
        aliases of "up" and "down".
    halign : {"left", "center", "right"}
        Horizontal alignment of the existing Text.
    width, height : float
        Extent of the existing Text, in points.
    npad : float
        Extra gap between the two artists, in points.

    The neighbor is assumed to be aligned the way set_title_neighbor
    aligns it: same alignment for "up"/"down", left-aligned when placed
    to the right, right-aligned when placed to the left.

    Raises
    ------
    ValueError
        For an unknown direction or horizontal alignment.
    """
    direction = {"top": "up", "bottom": "down"}.get(direction, direction)
    if direction == "up":
        return 0, height + npad
    if direction == "down":
        return 0, -height - npad
    if direction not in ("left", "right"):
        raise ValueError("direction must be up, down, left, or right")
    # Distance from the anchor to the right and to the left edge of the text.
    edges = {
        "left": (width, 0),
        "center": (width / 2, width / 2),
        "right": (0, width),
    }
    if halign not in edges:
        raise ValueError(f"unsupported horizontal alignment {halign!r}")
    to_right, to_left = edges[halign]
    if direction == "right":
        return to_right + npad, 0
    return -(to_left + npad), 0


def get_neighbor_transform(text, direction="top", npad=0):
    """
    Determine the transform for the *next* Text plotted with
    set_title_neighbor, expressed as an offset from ``text``'s transform.

    For alignment purposes this requires rendering ``text`` and measuring
    its extent.

    Parameters
    ----------
    text : matplotlib.text.Text
        The most recently placed artist in the chain.
    direction : {"up", "down", "left", "right"}
        Where the next artist goes ("top"/"bottom" are aliases).
    npad : float
        Gap between neighboring artists, in points.

    Raises
    ------
    ValueError
        For an unknown direction or horizontal alignment.
    """
    figure = text.figure
    renderer = figure.canvas.get_renderer()
    # Render once so the text layout, and therefore its extent, is current.
    text.draw(renderer)
    extent = text.get_window_extent(renderer)
    # The extent is in display pixels, while offset_copy(units="points")
    # expects points. Convert so the gap does not grow with the figure dpi.
    px_to_pt = 72.0 / figure.dpi
    x, y = neighbor_offset_points(
        direction,
        text.get_horizontalalignment(),
        extent.width * px_to_pt,
        extent.height * px_to_pt,
        npad,
    )
    return mtransforms.offset_copy(
        text.get_transform(), x=x, y=y, fig=figure, units="points"
    )
# Maps each ``loc`` value accepted by Axes.set_title to the name of the Axes
# attribute that holds the Text artist for that title slot.
LOC_ATTR = dict(
    center="title",
    right="_right_title",
    left="_left_title",
)
def set_title_neighbor(ax, label, loc="center", **kwargs):
    """
    Replacement for ``Axes.set_title`` installed by init_neighbor_title.

    corner.py is inflexible and only writes its per-parameter summary
    titles by calling Axes.set_title. This method adds collation behavior:
    once init_neighbor_title has been called for ``loc`` on this axis, each
    call places a new Text artist next to the previous one (in the
    configured direction) instead of overwriting the title. When the user
    hasn't called init_neighbor_title for that axis and title loc, it
    defers to the original implementation.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axis patched by init_neighbor_title (has ``_orig_set_title`` and
        ``_neighbor_titles``).
    label : str
        Title text.
    loc : {"center", "left", "right"}
        Title slot to use.
    **kwargs
        Text properties such as ``color``. For every artist after the first
        in a chain, the alignment keywords (including the ``ha``/``va``
        aliases), ``pad`` and ``y`` are ignored. Those artists take their
        placement from the base title.

    Returns
    -------
    matplotlib.text.Text
        The artist that was created or updated, as ``Axes.set_title`` does.

    Raises
    ------
    ValueError
        If the configured direction is not up, down, right or left.
    """
    neighbor_entry = ax._neighbor_titles.get(loc)
    if neighbor_entry is None:
        return ax._orig_set_title(label, loc=loc, **kwargs)

    neighbor_artists = neighbor_entry["artists"]
    direction = neighbor_entry["direction"]

    if not neighbor_artists:
        # The first title in this slot is an ordinary title; later ones are
        # positioned relative to it.
        text = ax._orig_set_title(label, loc=loc, **kwargs)
    else:
        base_title = getattr(ax, LOC_ATTR[loc])
        # Placement is derived from the base title. Drop keywords that would
        # clash with the explicit alignment/position arguments below, or that
        # ax.text does not accept.
        for key in ("horizontalalignment", "ha",
                    "verticalalignment", "va", "pad", "y"):
            kwargs.pop(key, None)
        if direction in ("up", "down"):
            halign = base_title.get_horizontalalignment()
        elif direction == "right":
            halign = "left"
        elif direction == "left":
            halign = "right"
        else:
            raise ValueError("direction must be up, down, right or left")
        # NOTE(review): the position is captured at creation time. If
        # matplotlib later moves the base title, neighbors will not follow.
        text = ax.text(
            *base_title.get_position(), label,
            transform=neighbor_entry["next_transform"],
            horizontalalignment=halign,
            verticalalignment=base_title.get_verticalalignment(),
            fontproperties=base_title.get_fontproperties(),
            **kwargs,
        )
    neighbor_artists.append(text)
    neighbor_entry["next_transform"] = get_neighbor_transform(
        text,
        direction=direction,
        npad=neighbor_entry["npad"],
    )
    return text
def get_all_title_artists(ax):
    """
    Yield every title Text artist of ``ax``.

    For each title slot that has neighbor artists, all of them are yielded.
    The slot's built-in title is the first of those, so it is not yielded
    twice. Slots without neighbors yield the plain Axes title attribute.

    Useful for the ``bbox_extra_artists`` parameter of Figure.savefig, or
    for modifying attributes of all titles at once.
    """
    remaining = dict(LOC_ATTR)
    for loc, entry in getattr(ax, "_neighbor_titles", {}).items():
        artists = entry["artists"]
        if artists:
            del remaining[loc]
            yield from artists
    for attr_name in remaining.values():
        yield getattr(ax, attr_name)
def init_neighbor_title(ax, loc="center", direction="right", npad=2):
    """
    Initialize neighbor-title behavior on ``ax`` by monkey-patching its
    ``set_title`` method.

    Afterwards, repeated set_title calls for ``loc`` create multiple Text
    artists (e.g. with different colors) placed next to each other instead
    of overwriting one another.

    Only the title with the given ``loc`` is affected. If you have both a
    left title and a center title, call this once for each; calls for
    different locs on the same axis share the same patched method.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axis to patch.
    loc : {"center", "left", "right"}
        Title slot to enable.
    direction : {"up", "down", "left", "right"}
        Where each subsequent neighbor is placed.
    npad : float
        Spacing between neighboring Text artists, in points.

    Can be called multiple times to change the direction and npad settings,
    although the changes only affect future Text artists.
    """
    # A bound method exposes its underlying function as ``__func__``.
    # Checking it keeps repeat calls from wrapping the already-patched
    # set_title (which would recurse forever) and from discarding the state
    # of other locs.
    if getattr(ax.set_title, "__func__", None) is not set_title_neighbor:
        ax._orig_set_title = ax.set_title
        ax.set_title = MethodType(set_title_neighbor, ax)
        ax.get_all_title_artists = MethodType(get_all_title_artists, ax)
        ax._neighbor_titles = dict()
    neighbor_entry = ax._neighbor_titles.setdefault(loc, {})
    neighbor_entry["direction"] = direction
    neighbor_entry["npad"] = npad
    neighbor_artists = neighbor_entry.setdefault("artists", [])
    # Adopt a non-empty title that already exists as the base of the chain.
    if not neighbor_artists and ax.get_title(loc=loc):
        neighbor_artists.append(getattr(ax, LOC_ATTR[loc]))
    if neighbor_artists:
        # Recompute where the next artist goes from the newest one, so that
        # updated direction/npad settings take effect.
        # get_neighbor_transform renders the text itself to measure it.
        neighbor_entry["next_transform"] = get_neighbor_transform(
            neighbor_artists[-1],
            direction=direction,
            npad=npad,
        )
# One shared 8x8-inch figure; every corner.corner call below draws into it.
fig = plt.figure(figsize=(8.0, 8.0))
def rescale_alpha(theta):
    """
    Convert alpha from log base e to log base 10 for graphing.

    Parameters
    ----------
    theta : array_like
        Samples whose last axis holds the parameters, with ln(alpha) at
        index 0. Any leading shape is allowed.

    Returns
    -------
    numpy.ndarray
        A new float array of the same shape. Index 0 of the last axis is
        divided by ln(10); all other parameters are unchanged. ``theta``
        itself is not modified.
    """
    rescaled = np.array(theta, dtype=float)  # copy; never mutate the input
    # log10(alpha) = ln(alpha) / ln(10)
    rescaled[..., 0] /= np.log(10)
    return rescaled
# Options shared by every corner.corner call below.
kwargs = dict(
    plot_datapoints=False,
    data_kwargs=dict(
        alpha=1 / 255,
    ),
    # Spelled "plot_countours" previously, which is not corner's option name.
    plot_contours=True,
    plot_density=True,
    show_titles=True,
    labelpad=-0.15,
    # Levels 1 - exp(-s**2 / 2) for s = 1, 2: the probability mass a 2-D
    # Gaussian encloses within 1 and 2 sigma.
    levels=1.0 - np.exp(-0.5 * np.arange(1.0, 2.1, 1.0) ** 2),
    bins=50,
)
# Raise every drawing layer of the first (primary) dataset to zorder 10 so
# it stays on top of the datasets plotted after it.
primary = {
    layer: {"zorder": 10}
    for layer in ("contour_kwargs", "contourf_kwargs", "pcolor_kwargs")
}
def _color_and_title(color, **title_kwargs):
    """
    Build corner.corner keyword arguments that give the plotted elements
    and their titles the same ``color``.

    Any extra keyword arguments are forwarded to the titles through
    ``title_kwargs``.
    """
    merged_title_kwargs = {"color": color}
    merged_title_kwargs.update(title_kwargs)
    return {"color": color, "title_kwargs": merged_title_kwargs}
"""
Make corner plots from a subset of the MCMC samples of each theta array.
Each theta array is an A x B x C array, where A is the
number of walkers, B is the number of MCMC draws
per walker, and C is the dimension of the parameter
array.
draw_sample is an array of indices (or a slice) selecting
the draws we actually want to visualize: use
slice(None, None) to keep all of them, or slice(N, None)
to skip the first N burn-in draws.
"""
# Plot the primary (uniform-prior) samples first. `primary` raises their
# zorder so they stay on top of the datasets plotted afterwards.
# NOTE(review): theta_* and draw_sample are defined outside this snippet, and
# corner.corner receives the indexed 3-D (walker, draw, parameter) array
# directly. Confirm that the installed corner version accepts 3-D input.
corner.corner(
    rescale_alpha(theta_uniform[:, draw_sample, :]),
    fig=fig, **_color_and_title("black", loc="left"),
    labels=[r"$\log_{10} \alpha$", r"T [K]", r"Noise"],
    **primary, **kwargs
)
# Patch set_title only now. The black titles just drawn become the base of
# each axis's left-title chain, and the set_title calls made by the later
# corner.corner calls add neighbors instead of overwriting them.
for ax in fig.axes:
    init_neighbor_title(ax, loc="left")
# Second dataset: titles appear in red next to the black ones.
corner.corner(
    rescale_alpha(theta_uniform_T4[:, draw_sample, :]),
    fig=fig, **_color_and_title("tab:red", loc="left"),
    **kwargs
)
# Third dataset: titles appear in blue after the red ones.
corner.corner(
    rescale_alpha(theta_log_uniform[:, draw_sample, :]),
    fig=fig, **_color_and_title("tab:blue", loc="left"),
    **kwargs
)
# Legend handles: corner.corner draws the 1-D histograms on the first axis as
# Polygon outlines. Polygons render as filled boxes in a legend, so each one
# is replaced by an empty Line2D with the same edge colour. Only the first
# three (one per dataset plotted above) are kept.
poly = [
    mlines.Line2D([], [], color=outline.get_edgecolor())
    for outline in fig.axes[0].get_children()
    if isinstance(outline, mpatches.Polygon)
]
del poly[3:]
poly[0].set_label("Uniform $T$ Prior")
poly[1].set_label("Fourth Power Uniform $T$ Prior")
poly[2].set_label("Log Uniform $T$ Prior")
# The legend goes on the second axis of the grid.
legend_axis = fig.axes[1]
legend_axis.legend(handles=poly)
# Pass every title artist (base titles and their coloured neighbors) as
# bbox_extra_artists so the tight bounding box accounts for all of them.
save_kwarg = {
    "bbox_inches": "tight",
    "bbox_extra_artists": [
        title_text
        for axis in fig.axes
        for title_text in axis.get_all_title_artists()
    ],
}
# Raster output at 300 dpi, plus a vector PDF.
fig.savefig("multicolor.png", dpi=300, **save_kwarg)
fig.savefig("multicolor.pdf", **save_kwarg)
@sahiljhawar
Copy link

Can you also provide dummy data, in order to see what actual input and output looks like?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment