# Script for generating example plots for a custom matplotlib style (with math equations).
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import matplotlib.pyplot as plt | |
import matplotlib.colors as mcolors | |
# Activate the custom style sheet by name. 'gs-style' must be discoverable by
# matplotlib as a style file, typically saved under ~/.matplotlib/stylelib/.
plt.style.use(style='gs-style')
# Seed NumPy's legacy global RNG so any direct np.random usage is
# reproducible. The plotting helpers draw from their own RandomState instead.
np.random.seed(seed=19680801)
def plot_scatter(ax, prng, nb_samples=100):
    """Draw two Gaussian point clouds on *ax* with different markers.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    prng : numpy.random.RandomState used to draw the points.
    nb_samples : int
        Number of points in each cloud.

    Returns
    -------
    The same ``ax``.
    """
    # (mean, standard deviation, marker) for each cloud; mean/std are shared
    # by the x and y coordinates.
    clouds = ((-.5, 0.75, 'o'), (0.75, 1., 's'))
    for loc, scale, mk in clouds:
        xs, ys = prng.normal(loc=loc, scale=scale, size=(2, nb_samples))
        # ls='none' draws markers only, i.e. a scatter plot via ax.plot.
        ax.plot(xs, ys, ls='none', marker=mk)
    ax.set_xlabel('X-label')
    ax.set_title('Axes title')
    return ax
def plot_colored_lines(ax):
    """Draw one shifted, scaled sigmoid per entry of the active color cycle.

    The number of curves equals ``len(plt.rcParams['axes.prop_cycle'])``, so
    each color of the style appears once. Axis labels and a TeX annotation
    are added as well.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.

    Returns
    -------
    The same ``ax``.
    """
    def _logistic(x, x0):
        # Standard logistic curve centred at x0.
        return 1 / (1 + np.exp(-(x - x0)))

    grid = np.linspace(-10, 10, 100)
    n_curves = len(plt.rcParams['axes.prop_cycle'])
    centers = np.linspace(-5, 5, n_curves)
    heights = np.linspace(1, 1.5, n_curves)
    for x0, height in zip(centers, heights):
        ax.plot(grid, height * _logistic(grid, x0), '-')
    ax.set_xlim(-10, 10)
    ax.set_xlabel(r'Wavevector $k\sigma$')
    ax.set_ylabel(r'S(k)')
    # \TeX and \displaystyle are LaTeX commands; NOTE(review): this presumably
    # relies on the style enabling text.usetex — confirm in gs-style.
    tex_note = (r'\TeX\ is Number $\displaystyle\sum_{n=1}^\infty'
                r'\frac{-e^{i\pi}}{2^n}$!')
    # Position is given in axes-fraction coordinates (transform=ax.transAxes).
    ax.text(0.4, 0.85, tex_note,
            horizontalalignment='center', verticalalignment='center',
            transform=ax.transAxes)
    return ax
def plot_bar_graphs(ax, prng, min_value=5, max_value=25, nb_samples=5):
    """Plot two bar graphs side by side, with letters as x-tick labels.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    prng : numpy.random.RandomState supplying the bar heights.
    min_value, max_value : int
        Heights are drawn with ``prng.randint(min_value, max_value)``, i.e.
        from the half-open range [min_value, max_value).
    nb_samples : int
        Number of bars in each series, and of x-tick labels.

    Returns
    -------
    The same ``ax``.
    """
    x = np.arange(nb_samples)
    ya, yb = prng.randint(min_value, max_value, size=(2, nb_samples))
    width = 0.25
    ax.bar(x, ya, width)
    # Second series is offset by one bar width so the bars sit side by side.
    ax.bar(x + width, yb, width, color='C2')
    # Generate exactly one label per tick ('a', 'b', 'c', ...). set_xticks
    # requires as many labels as ticks, so a fixed-length list would break
    # any nb_samples other than its length. Letters wrap around after 'z'.
    labels = [chr(ord('a') + i % 26) for i in range(nb_samples)]
    ax.set_xticks(x + width, labels=labels)
    return ax
def plot_colored_circles(ax, prng, nb_samples=15):
    """
    Plot circle patches.

    NB: draws a fixed amount of samples, rather than using the length of
    the color cycle, because different styles may have different numbers
    of colors.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    prng : numpy.random.RandomState used for the circle centres.
    nb_samples : int
        Exact number of circles drawn; colors repeat if the style's cycle
        is shorter.

    Returns
    -------
    The same ``ax``.
    """
    # Calling the Cycler yields an endless iterator over its entries.
    # Iterating the Cycler itself would stop after len(cycle) entries, which
    # silently drew fewer than nb_samples circles for short color cycles.
    color_cycle = plt.rcParams['axes.prop_cycle']()
    for _, sty_dict in zip(range(nb_samples), color_cycle):
        center = prng.normal(scale=3, size=2)
        ax.add_patch(plt.Circle(center, radius=1.0, color=sty_dict['color']))
    # Force the limits to be the same across the styles (because different
    # styles may have different numbers of available colors).
    ax.set_xlim([-4, 8])
    ax.set_ylim([-5, 6])
    ax.set_aspect('equal', adjustable='box')  # to plot circles as circles
    return ax
def plot_image_and_patch(ax, prng, size=(20, 20)):
    """Plot an image with random values and superimpose a circular patch.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    prng : numpy.random.RandomState used for the pixel values.
    size : tuple of int
        Shape of the random image passed to ``prng.random_sample``.

    Returns
    -------
    The same ``ax``, consistent with the other ``plot_*`` helpers.
    """
    values = prng.random_sample(size=size)
    # interpolation='none' shows each random pixel as a crisp square.
    ax.imshow(values, interpolation='none')
    c = plt.Circle((5, 5), radius=5, label='patch')
    ax.add_patch(c)
    # Remove ticks
    ax.set_xticks([])
    ax.set_yticks([])
    return ax
def plot_histograms(ax, prng, nb_samples=10000):
    """Overlay four Beta-distribution histograms and add an annotation.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    prng : numpy.random.RandomState used for the Beta samples.
    nb_samples : int
        Number of samples per histogram.

    Returns
    -------
    The same ``ax``.
    """
    # (a, b) shape parameters of each Beta distribution, in drawing order.
    shape_pairs = ((10, 10), (4, 12), (50, 12), (6, 55))
    samples = [prng.beta(a, b, size=nb_samples) for a, b in shape_pairs]
    for data in samples:
        ax.hist(data, histtype="stepfilled", bins=30, alpha=0.8, density=True)

    # A small annotation: text placed in axes-fraction coordinates, with an
    # angled arrow pointing at a data location.
    box_style = dict(boxstyle="round", alpha=0.2)
    arrow_style = dict(arrowstyle="->",
                       connectionstyle="angle,angleA=-95,angleB=35,rad=10")
    ax.annotate('Annotation', xy=(0.25, 4.25),
                xytext=(0.9, 0.9), textcoords=ax.transAxes,
                va="top", ha="right",
                bbox=box_style, arrowprops=arrow_style)
    return ax
def plot_figure(style_label=""):
    """Setup and plot the demonstration figure with a given style.

    Parameters
    ----------
    style_label : str
        Used both as the figure's ``num`` identifier and as its suptitle.

    Returns
    -------
    matplotlib.figure.Figure
        The created six-panel figure. It is returned so callers can save or
        close it explicitly instead of relying on pyplot's "current figure".
    """
    # Use a dedicated RandomState instance to draw the same "random" values
    # across the different figures.
    prng = np.random.RandomState(96917002)
    # layout='constrained' is the preferred spelling of the older
    # constrained_layout=True keyword.
    fig, axs = plt.subplots(ncols=6, nrows=1, num=style_label,
                            figsize=(14.8, 2.7), layout='constrained')
    # make a suptitle, in the same style for all subfigures,
    # except those with dark backgrounds, which get a lighter color:
    # index [2] of the HSV triple is the "value" (brightness) channel.
    background_color = mcolors.rgb_to_hsv(
        mcolors.to_rgb(plt.rcParams['figure.facecolor']))[2]
    if background_color < 0.5:
        title_color = [0.8, 0.8, 1]
    else:
        title_color = np.array([19, 6, 84]) / 256
    fig.suptitle(style_label, x=0.01, ha='left', color=title_color,
                 fontsize=14, fontweight='normal')
    # Panel order matters: every helper but plot_colored_lines consumes prng,
    # so reordering the calls would change the drawn values.
    plot_scatter(axs[0], prng)
    plot_image_and_patch(axs[1], prng)
    plot_bar_graphs(axs[2], prng)
    plot_colored_circles(axs[3], prng)
    plot_colored_lines(axs[4])
    plot_histograms(axs[5], prng)
    return fig
# Figure 1: six-panel overview of the active style, saved at 300 dpi.
plot_figure(style_label='')
plt.savefig('custom_matplotlib_style.png', dpi=300)

# Figure 2: interface-tracking profiles annotated with LaTeX.
# NOTE(review): \bf, eqnarray* and \mathcal below are LaTeX constructs, so
# this figure presumably needs text.usetex enabled by the style — confirm.
fig, ax = plt.subplots(figsize=(8, 8 / 1.5))
# interface tracking profiles
N = 500
delta = 0.6  # interface width, also used for the arrow below
X = np.linspace(-1, 1, N)
ax.plot(X, (1 - np.tanh(4 * X / delta)) / 2,          # phase field tanh profile
        X, (1.4 + np.tanh(4 * X / delta)) / 4, "C2",  # level-set profile
        X, X < 0, "k--")                              # sharp interface
# legend (entries follow the order of the three curves above)
ax.legend(("phase field", "level set", "sharp interface"),
          shadow=True, loc=(0.01, 0.48), handlelength=1.5, fontsize=16)
# the arrow: a double-headed arrow spanning the interface width delta
ax.annotate("", xy=(-delta / 2., 0.1), xytext=(delta / 2., 0.1),
            arrowprops=dict(arrowstyle="<->", connectionstyle="arc3"))
ax.text(0, 0.1, r"$\delta$",
        color="black", fontsize=24,
        horizontalalignment="center", verticalalignment="center",
        bbox=dict(boxstyle="round", fc="white", ec="black", pad=0.2))
# Use tex in labels; ticks and their labels are set in one call, as in
# plot_bar_graphs.
ax.set_xticks([-1, 0, 1], labels=["$-1$", r"$\pm 0$", "$+1$"],
              color="k", size=20)
# Left Y-axis labels, combine math mode and text mode
ax.set_ylabel(r"\bf{phase field} $\phi$", color="C0", fontsize=20)
ax.set_yticks([0, 0.5, 1], labels=[r"\bf{0}", r"\bf{.5}", r"\bf{1}"],
              color="k", size=20)
# Right Y-axis labels: placed just outside the axes (x=1.02 in axes
# coordinates), so clipping is disabled.
ax.text(1.02, 0.5, r"\bf{level set} $\phi$",
        color="C2", fontsize=20, rotation=90,
        horizontalalignment="left", verticalalignment="center",
        clip_on=False, transform=ax.transAxes)
# Use multiline environment inside a `text`.
# level set equations
eq1 = (r"\begin{eqnarray*}"
       r"|\nabla\phi| &=& 1,\\"
       r"\frac{\partial \phi}{\partial t} + U|\nabla \phi| &=& 0 "
       r"\end{eqnarray*}")
ax.text(1, 0.9, eq1, color="C2", fontsize=18,
        horizontalalignment="right", verticalalignment="top")
# phase field equations
eq2 = (r"\begin{eqnarray*}"
       r"\mathcal{F} &=& \int f\left( \phi, c \right) dV, \\ "
       r"\frac{ \partial \phi } { \partial t } &=& -M_{ \phi } "
       r"\frac{ \delta \mathcal{F} } { \delta \phi }"
       r"\end{eqnarray*}")
ax.text(0.18, 0.18, eq2, color="C0", fontsize=16)
ax.text(-1, .30, r"gamma: $\gamma$", color="r", fontsize=20)
ax.text(-1, .18, r"Omega: $\Omega$", color="b", fontsize=20)
plt.savefig('custom_matplotlib_style_2.png')
plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment