Last active
May 22, 2023 11:55
-
-
Save mattpitkin/07627fcd721821220ef9290c3a7314fd to your computer and use it in GitHub Desktop.
A network plot of 2x2 grids
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
from matplotlib import pyplot as plt | |
from matplotlib.patches import Rectangle | |
# Create the figure first, then a single set of axes on it; all drawing
# below goes onto ``ax``.
fig = plt.figure()
ax = fig.subplots()
def draw_2x2_grid(ax, xy, bwidth=0.25, bheight=0.25, numbers=None):
    """
    Draw a 2x2 grid of outlined boxes, optionally labelling each box.

    Parameters
    ----------
    ax
        Object providing ``add_patch`` and ``text`` (presumably a Matplotlib
        ``Axes`` -- confirm against the caller).
    xy
        Pair ``(x, y)`` giving the bottom-left corner of the whole grid.
    bwidth, bheight
        Width and height of each individual box in the grid.
    numbers
        Optional sequence of four labels. The label for the box in column
        ``i`` and row ``j`` (both counted from ``xy``) is
        ``numbers[2 * i + j]``; each label is passed through ``str``.
        If ``None`` or empty, the boxes are drawn without labels.
    """
    # An absent/empty label list used to raise IndexError on the first
    # lookup; treat it as "no labels" instead. ``len`` is used rather than
    # truthiness so array-like inputs work as well.
    has_labels = numbers is not None and len(numbers) > 0

    for i in range(2):
        x = xy[0] + i * bwidth
        for j in range(2):
            y = xy[1] + j * bheight

            # unfilled box with a thick black outline
            ax.add_patch(
                Rectangle(
                    (x, y),
                    bwidth,
                    bheight,
                    facecolor="none",
                    edgecolor="k",
                    lw=3,
                )
            )

            if has_labels:
                # label centred in the box
                ax.text(
                    x + bwidth / 2,
                    y + bheight / 2,
                    str(numbers[i * 2 + j]),
                    va="center",
                    ha="center",
                    fontsize=20,
                )
def draw_arrow(ax, xystart, xyend, textabove=None, textbelow=None, arrowkwargs=None):
    """
    Draw and annotate an arrow between two points.

    Parameters
    ----------
    ax
        Object providing ``arrow`` and ``text`` (presumably a Matplotlib
        ``Axes`` -- confirm against the caller).
    xystart, xyend
        Pairs ``(x, y)`` giving the tail and the tip of the arrow.
    textabove, textbelow
        Optional labels anchored at the arrow's midpoint. ``textabove`` is
        placed with ``va="bottom"`` and ``textbelow`` with ``va="top"``, so
        they end up on opposite sides of the arrow. Both are rotated by the
        arrow's angle.
    arrowkwargs
        Optional dict of extra keyword arguments forwarded to ``ax.arrow``.
        ``None`` (the default) means no extra arguments.
    """
    # Use None rather than a shared mutable ``{}`` as the default.
    if arrowkwargs is None:
        arrowkwargs = {}

    # get displacements dx, dy
    dx = xyend[0] - xystart[0]
    dy = xyend[1] - xystart[1]

    # direction of the arrow in degrees, used to rotate the labels
    angle = np.rad2deg(np.arctan2(dy, dx))

    ax.arrow(xystart[0], xystart[1], dx, dy, **arrowkwargs)

    # add text: both labels share the midpoint anchor and differ only in
    # their vertical alignment relative to it
    xmid = xystart[0] + dx / 2
    ymid = xystart[1] + dy / 2
    for label, valign in ((textabove, "bottom"), (textbelow, "top")):
        if label is None:
            continue
        ax.text(
            xmid,
            ymid,
            label,
            ha="center",
            va=valign,
            transform_rotates_text=True,
            rotation=angle,
            rotation_mode='anchor',
            fontsize=12,
        )
# width and height of each subbox within a 2x2 grid
bwidth = 0.3
bheight = 0.3

# bottom left x, y positions of boxes
boxpositions = [
    (0, 0),
    (0 + 3.5 * bwidth, 0 - 3 * bheight),
    (0 + 3.5 * bwidth, 0 + 2 * bheight),
]

# numbers in boxes
boxnumbers = [
    [0, 1, 0, 0],
    [2, 3, 2, 0],
    [1, 2, 2, 1],
]

# draw boxes
for bpos, bnum in zip(boxpositions, boxnumbers):
    draw_2x2_grid(ax, bpos, bwidth=bwidth, bheight=bheight, numbers=bnum)

# draw arrows: both leave from the middle of the right-hand edge of the
# first grid and end at the middle of the left-hand edge of the target grid
arrowstart = (boxpositions[0][0] + 2 * bwidth, boxpositions[0][1] + bheight)

draw_arrow(
    ax,
    arrowstart,
    (boxpositions[1][0], boxpositions[1][1] + bheight),
    textbelow="$t = 2$",
    textabove="$a_2(0)$",
    arrowkwargs={"color": "r", "head_width": 0.05, "length_includes_head": True, "lw": 2},
)
draw_arrow(
    ax,
    arrowstart,
    (boxpositions[2][0], boxpositions[2][1] + bheight),
    textbelow="$t = 1$",
    textabove="$a_1(1)$",
    arrowkwargs={"color": "k", "head_width": 0.05, "length_includes_head": True, "lw": 2},
)

# equal scaling on both axes (so the boxes come out square), and hide the
# axes decorations
ax.axis("equal")
ax.set_axis_off()

# Use pyplot's show() rather than Figure.show(): the latter does not run a
# GUI event loop, so in a plain script the window may close immediately,
# and it warns on non-GUI backends.
plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment