Last active
May 25, 2023 13:36
-
-
Save alchzh/7f5fb3d01e53b2ad737db486d065d1a0 to your computer and use it in GitHub Desktop.
Plotting multiple titles on an axis via monkey-patching (for corner.py)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from types import MethodType

import corner
import matplotlib.lines as mlines
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import matplotlib.transforms as mtransforms
import numpy as np
# A Matplotlib Text artist cannot mix several colors, so we need a few hacks
# to place separately colored Text artists neatly next to each other.
# Ideally this would be implemented as a new Artist class that holds the
# individual Text objects.
def get_neighbor_transform(text, direction="up", npad=0):
    """
    Return the transform for the *next* Text placed beside ``text``.

    Used by set_title_neighbor to chain Text artists together. For
    alignment purposes ``text`` is rendered once so that its window extent
    is known.

    Parameters
    ----------
    text : matplotlib.text.Text
        The most recently placed artist of the chain.
    direction : {"up", "down", "left", "right"}, default "up"
        Side of ``text`` on which the next artist goes. (The previous
        default, "top", was not an accepted value and always raised.)
    npad : float, default 0
        Gap between ``text`` and the next artist, in points.

    Returns
    -------
    matplotlib.transforms.Transform
        ``text``'s own transform shifted by the size of ``text`` plus
        ``npad`` in the requested direction.

    Raises
    ------
    ValueError
        If ``direction`` is not one of the accepted values.
    """
    figure = text.figure
    renderer = figure.canvas.get_renderer()
    text.draw(renderer)
    extent = text.get_window_extent(renderer)
    # The extent is measured in display pixels, while offset_copy below
    # works in points; convert so the spacing is correct at any dpi.
    pixels_to_points = 72.0 / figure.dpi
    width = extent.width * pixels_to_points
    height = extent.height * pixels_to_points

    if direction == "up":
        x, y = 0, height + npad
    elif direction == "down":
        x, y = 0, -height - npad
    elif direction in ("left", "right"):
        halign = text.get_horizontalalignment()
        # Fraction of the text's width lying left of its anchor point.
        left_fraction = {"left": 0.0, "center": 0.5, "right": 1.0}[halign]
        if direction == "right":
            # set_title_neighbor left-aligns the next artist, so its anchor
            # goes just past the right edge of ``text``.
            x = (1.0 - left_fraction) * width + npad
        else:
            # set_title_neighbor right-aligns the next artist, so its anchor
            # goes just before the left edge of ``text``.
            x = -left_fraction * width - npad
        y = 0
    else:
        raise ValueError("direction must be up, down, left or right")

    return mtransforms.offset_copy(
        text.get_transform(), x=x, y=y,
        fig=figure, units="points",
    )
# Maps each ``loc`` value of Axes.set_title to the name of the Axes
# attribute that holds the corresponding title Text artist.
LOC_ATTR = dict(
    center="title",
    right="_right_title",
    left="_left_title",
)
def set_title_neighbor(ax, label, loc="center", **kwargs):
    """
    Replacement for Axes.set_title that lays successive titles side by side.

    corner.py is inflexible and only outputs the medians and quartiles
    through Axes.set_title, so init_neighbor_title monkey-patches this
    function in as the Axes' set_title. It defers to the original
    implementation when the user hasn't called init_neighbor_title on that
    axis and title loc.

    For an initialized ``loc``, the first call sets the real title. Every
    later call creates a new Text artist placed next to the previous one
    (in the configured direction), reusing the base title's position,
    vertical alignment and font.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        The (patched) Axes.
    label : str
        Title text.
    loc : str, default "center"
        Which title to set.
    **kwargs
        Passed to the original set_title for the first artist, or to
        Axes.text for later artists. For later artists ``fontdict`` is
        merged in first (as set_title does); position, alignment,
        transform, font properties and ``pad`` are ignored, since those
        come from the base title.

    Returns
    -------
    matplotlib.text.Text
        The Text artist that was set or created.

    Raises
    ------
    ValueError
        If the configured direction is not up, down, right or left.
    """
    neighbor_entry = ax._neighbor_titles.get(loc)
    if neighbor_entry is None:
        # Not initialized for this loc: behave exactly like Axes.set_title.
        return ax._orig_set_title(label, loc=loc, **kwargs)

    neighbor_artists = neighbor_entry["artists"]
    direction = neighbor_entry["direction"]

    if not neighbor_artists:
        # The first label of the chain becomes the real title artist.
        text = ax._orig_set_title(label, loc=loc, **kwargs)
    else:
        base_title = getattr(ax, LOC_ATTR[loc])
        if direction in ("up", "down"):
            halign = base_title.get_horizontalalignment()
        elif direction == "right":
            halign = "left"
        elif direction == "left":
            halign = "right"
        else:
            raise ValueError("direction must be up, down, right or left")

        # set_title applies ``fontdict`` before the remaining kwargs.
        text_kwargs = dict(kwargs.pop("fontdict", None) or {})
        text_kwargs.update(kwargs)
        # These are passed explicitly below (a duplicate would make
        # Axes.text raise TypeError) or only make sense for set_title.
        for key in (
            "x", "y", "pad", "transform",
            "horizontalalignment", "ha",
            "verticalalignment", "va",
            "fontproperties", "font_properties", "font",
        ):
            text_kwargs.pop(key, None)

        text = ax.text(
            *base_title.get_position(), label,
            transform=neighbor_entry["next_transform"],
            horizontalalignment=halign,
            verticalalignment=base_title.get_verticalalignment(),
            fontproperties=base_title.get_fontproperties(),
            **text_kwargs
        )

    neighbor_artists.append(text)
    # Computed for the first artist as well, so the next call always finds
    # ``next_transform`` even when the title was empty at init time.
    neighbor_entry["next_transform"] = get_neighbor_transform(
        text,
        direction=direction,
        npad=neighbor_entry["npad"],
    )
    return text
def get_all_title_artists(ax):
    """
    Yield every title Text artist of ``ax``.

    For each loc registered with init_neighbor_title that already holds
    artists, the whole chain (base title first) is yielded. Then every
    standard title listed in LOC_ATTR that was not already yielded is
    yielded once. Works on axes that were never initialized as well.

    Useful for the "bbox_extra_artists" parameter of Figure.savefig or for
    modifying attributes of all titles.
    """
    yielded_ids = set()
    for neighbor_entry in getattr(ax, "_neighbor_titles", {}).values():
        for artist in neighbor_entry["artists"]:
            yielded_ids.add(id(artist))
            yield artist
    for attr in LOC_ATTR.values():
        title = getattr(ax, attr)
        # Skip by identity rather than popping by loc key: a loc that is not
        # a LOC_ATTR key (e.g. None) previously raised KeyError here.
        if id(title) not in yielded_ids:
            yield title
def init_neighbor_title(ax, loc="center", direction="right", npad=2):
    """
    Enable neighbor titles on ``ax`` for the title at ``loc``.

    Monkey-patches ``ax.set_title`` with set_title_neighbor so that repeated
    set_title calls for ``loc`` create multiple Text artists (with, say,
    different colors) laid out next to each other instead of replacing the
    title. Also attaches ``ax.get_all_title_artists``.

    Only affects the title with the loc specified, so if you have both a
    left title and a center title you need to call this once for each.
    Can be called multiple times to change the direction and npad
    settings, although the changes only affect future Text artists. A
    title that is already non-empty becomes the first artist of the chain.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes to patch.
    loc : str, default "center"
        Which title ("center", "left" or "right") to enable this for.
    direction : {"up", "down", "left", "right"}, default "right"
        Where each subsequent neighbor is placed relative to the previous.
    npad : float, default 2
        Spacing between consecutive Text artists, in points.

    Raises
    ------
    ValueError
        If ``direction`` is not one of the accepted values.
    """
    if direction not in ("up", "down", "left", "right"):
        raise ValueError("direction must be up, down, right or left")

    # Bound methods expose their function as ``__func__``. Checking it stops
    # a repeated call from wrapping set_title_neighbor around itself (which
    # recursed forever) and from discarding the already registered locs.
    if getattr(ax.set_title, "__func__", None) is not set_title_neighbor:
        ax._orig_set_title = ax.set_title
        ax.set_title = MethodType(set_title_neighbor, ax)
        ax.get_all_title_artists = MethodType(get_all_title_artists, ax)
        ax._neighbor_titles = {}

    neighbor_entry = ax._neighbor_titles.setdefault(loc, {})
    neighbor_entry["direction"] = direction
    neighbor_entry["npad"] = npad
    neighbor_artists = neighbor_entry.setdefault("artists", [])

    if not neighbor_artists and ax.get_title(loc=loc):
        # Adopt the existing title as the first artist of the chain.
        neighbor_artists.append(getattr(ax, LOC_ATTR[loc]))

    if neighbor_artists:
        # get_neighbor_transform renders the artist itself to measure it.
        neighbor_entry["next_transform"] = get_neighbor_transform(
            neighbor_artists[-1],
            direction=direction,
            npad=npad,
        )
# One shared 8 x 8 inch figure; every corner.corner call below draws into it.
fig = plt.figure(figsize=(8.0, 8.0))
def rescale_alpha(theta):
    """
    Convert alpha from log base e to log base 10 for graphing.

    The first parameter, ``theta[..., 0]``, holds ln(alpha) and is divided
    by ln(10) so that it becomes log10(alpha). All other parameters are left
    unchanged. Works for any array whose last axis indexes the parameters.
    The input is not modified; a new float array is returned.
    """
    # The previous ``theta / (np.log(10), 1)`` broadcast a length-2 tuple
    # against the parameter axis, which fails for three parameters.
    rescaled = np.array(theta, dtype=float)  # always a copy
    rescaled[..., 0] /= np.log(10)
    return rescaled
# Keyword arguments shared by every corner.corner call below.
kwargs = dict(
    plot_datapoints=False,
    # NOTE(review): presumably only used by corner when plot_datapoints=True.
    data_kwargs=dict(
        alpha=1/255,
    ),
    # Was misspelled "plot_countours", which is not a corner option.
    plot_contours=True,
    plot_density=True,
    show_titles=True,
    labelpad=-0.15,
    # 1 - exp(-s**2 / 2) is the mass of a 2-D Gaussian enclosed within
    # s sigma; these are the 1- and 2-sigma contour levels.
    levels=1.0 - np.exp(-0.5 * np.array([1.0, 2.0]) ** 2),
    bins=50,
)
# Extra corner.corner options for the first dataset only: its contour,
# filled-contour and pcolor artists get zorder 10, so they are drawn above
# the datasets plotted after it (which keep the default zorder).
primary = {
    key: {"zorder": 10}
    for key in ("contour_kwargs", "contourf_kwargs", "pcolor_kwargs")
}
def _color_and_title(color, **title_kwargs):
    """
    Return corner.corner keyword arguments that set both the color of the
    plotted elements and the color of the associated title.

    Any extra keyword arguments (e.g. ``loc="left"``) are added to the
    ``title_kwargs`` dict after ``color``.
    """
    return {"color": color, "title_kwargs": {"color": color, **title_kwargs}}


# Make corner plots from a subset of the MCMC samples of each run below.
#
# Each theta array is an A x B x C array, where A is the number of walkers,
# B is the number of MCMC draws per walker, and C is the dimension of the
# parameter array.
#
# draw_sample selects the draws (axis B) to visualize: an array of indices,
# slice(None, None) to use all of them, or slice(N, None) to skip the first
# N draws as burn-in.
# NOTE(review): theta_uniform, theta_uniform_T4, theta_log_uniform and
# draw_sample are not defined in this script and must exist before this
# point. Passing a 3-D (walkers x draws x parameters) array presumably relies
# on corner's ArviZ input path, which reads it as (chain, draw, parameter)
# -- confirm.
corner.corner(
    rescale_alpha(theta_uniform[:, draw_sample, :]),
    fig=fig,
    labels=[r"$\log_{10} \alpha$", r"T [K]", r"Noise"],
    **_color_and_title("black", loc="left"),
    **primary,
    **kwargs
)

# From now on, each set_title(..., loc="left") on these axes adds another
# title to the right of the existing one instead of replacing it, so the
# title text of every dataset appears side by side in its own color.
for ax in fig.axes:
    init_neighbor_title(ax, loc="left")

corner.corner(
    rescale_alpha(theta_uniform_T4[:, draw_sample, :]),
    fig=fig,
    **_color_and_title("tab:red", loc="left"),
    **kwargs
)
corner.corner(
    rescale_alpha(theta_log_uniform[:, draw_sample, :]),
    fig=fig,
    **_color_and_title("tab:blue", loc="left"),
    **kwargs
)

# Create a legend for the figure from the histogram polygons created by
# corner.corner on the first axes (one per dataset, in plotting order).
# Polygons would display as boxes in the legend, so use single-colored lines
# with the same edge color instead.
hist_polygons = [
    artist
    for artist in fig.axes[0].get_children()
    if isinstance(artist, mpatches.Polygon)
][:3]
prior_labels = (
    "Uniform $T$ Prior",
    "Fourth Power Uniform $T$ Prior",
    "Log Uniform $T$ Prior",
)
if len(hist_polygons) < len(prior_labels):
    raise RuntimeError(
        "expected one histogram polygon per dataset on fig.axes[0], "
        "found %d" % len(hist_polygons)
    )
poly = [
    mlines.Line2D([], [], color=polygon.get_edgecolor(), label=label)
    for polygon, label in zip(hist_polygons, prior_labels)
]
# fig.axes[1] is presumably an unused upper-triangle panel of the grid.
fig.axes[1].legend(handles=poly)

save_kwarg = dict(
    bbox_inches="tight",
    # Make the tight bounding box account for every (neighbor) title.
    bbox_extra_artists=[
        title for ax in fig.axes for title in get_all_title_artists(ax)
    ],
)
fig.savefig("multicolor.png", dpi=300, **save_kwarg)
fig.savefig("multicolor.pdf", **save_kwarg)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Can you also provide dummy data, so that we can see what the actual input and output look like?