Skip to content

Instantly share code, notes, and snippets.

@mattpitkin
Last active May 22, 2023 11:55
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save mattpitkin/07627fcd721821220ef9290c3a7314fd to your computer and use it in GitHub Desktop.
A network plot of 2x2 grids
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.patches import Rectangle
# Create the figure and the single Axes on which the script below draws
# all of the 2x2 grids and the arrows connecting them.
fig, ax = plt.subplots()
def draw_2x2_grid(
    ax,
    xy,
    bwidth=0.25,
    bheight=0.25,
    numbers=None,
    fontsize=20,
    lw=3,
    edgecolor="k",
):
    """
    Draw a 2x2 grid of outlined boxes, optionally labelling each box.

    Parameters
    ----------
    ax : matplotlib Axes
        The axes to draw on. One unfilled ``Rectangle`` patch is added per
        box, plus one centred text label per box when ``numbers`` is given.
    xy : tuple of float
        The (x, y) position of the bottom-left corner of the whole grid.
    bwidth, bheight : float
        The width and height of each individual box.
    numbers : sequence or None
        Labels for the four boxes, each converted with ``str``. They are
        consumed column by column, bottom row first: ``numbers[0]`` is
        bottom-left, ``numbers[1]`` top-left, ``numbers[2]`` bottom-right
        and ``numbers[3]`` top-right. Entries beyond the fourth are ignored.
        If ``None`` (the default) the boxes are drawn without labels.
    fontsize : float
        Font size of the labels.
    lw : float
        Line width of the box outlines.
    edgecolor : color
        Colour of the box outlines.

    Raises
    ------
    ValueError
        If ``numbers`` is given but has fewer than four entries.
    """
    if numbers is not None and len(numbers) < 4:
        raise ValueError(
            f"numbers must contain at least 4 labels, got {len(numbers)}"
        )

    for i in range(2):  # column index: 0 = left, 1 = right
        x = xy[0] + i * bwidth
        for j in range(2):  # row index: 0 = bottom, 1 = top
            y = xy[1] + j * bheight
            ax.add_patch(
                Rectangle(
                    (x, y),
                    bwidth,
                    bheight,
                    facecolor="none",
                    edgecolor=edgecolor,
                    lw=lw,
                )
            )
            if numbers is None:
                continue
            # place the label at the centre of this box
            ax.text(
                x + bwidth / 2,
                y + bheight / 2,
                str(numbers[i * 2 + j]),
                va="center",
                ha="center",
                fontsize=fontsize,
            )
def draw_arrow(
    ax,
    xystart,
    xyend,
    textabove=None,
    textbelow=None,
    arrowkwargs=None,
    fontsize=12,
    keep_upright=True,
):
    """
    Draw and annotate an arrow between two points.

    The arrow is drawn with ``ax.arrow`` from ``xystart`` to ``xyend``.
    Optional labels are placed at the arrow's midpoint, rotated to follow
    the arrow's direction, one on each side of it.

    Parameters
    ----------
    ax : matplotlib Axes
        The axes to draw on.
    xystart, xyend : tuple of float
        The (x, y) data coordinates of the arrow's tail and end point.
    textabove, textbelow : str or None
        Labels drawn at the arrow's midpoint. ``textabove`` has its bottom
        edge on the arrow (``va="bottom"``); ``textbelow`` hangs from the
        arrow (``va="top"``). ``None`` skips that label.
    arrowkwargs : dict or None
        Extra keyword arguments forwarded unchanged to ``ax.arrow``
        (e.g. ``color``, ``head_width``). ``None`` means no extra arguments.
    fontsize : float
        Font size of the labels.
    keep_upright : bool
        If True (the default), the label rotation is folded into the range
        (-90, 90] degrees. Labels on leftward- or downward-pointing arrows
        are then not drawn upside down, and ``textabove`` does not end up
        beneath the arrow. If False, the labels use the arrow's raw angle.
    """
    if arrowkwargs is None:
        arrowkwargs = {}

    # displacement from tail to head
    dx = xyend[0] - xystart[0]
    dy = xyend[1] - xystart[1]

    ax.arrow(xystart[0], xystart[1], dx, dy, **arrowkwargs)

    # direction of the arrow in degrees, in [-180, 180]
    angle = np.rad2deg(np.arctan2(dy, dx))
    if keep_upright:
        # A rotation beyond +/-90 degrees renders the text upside down and
        # flips which side of the arrow "above"/"below" land on, so turn it
        # by half a revolution instead.
        if angle > 90:
            angle -= 180
        elif angle <= -90:
            angle += 180

    midx = xystart[0] + dx / 2
    midy = xystart[1] + dy / 2

    for text, valign in ((textabove, "bottom"), (textbelow, "top")):
        if text is None:
            continue
        ax.text(
            midx,
            midy,
            text,
            ha="center",
            va=valign,
            # let the axes transform convert `rotation` (computed from data
            # displacements) so the label follows the arrow on screen
            transform_rotates_text=True,
            rotation=angle,
            rotation_mode="anchor",
            fontsize=fontsize,
        )
# Every sub-box of a 2x2 grid is bwidth wide and bheight tall (data units).
bwidth = bheight = 0.3

# Bottom-left corner of each 2x2 grid: the first grid sits at the origin and
# the other two are shifted 3.5 box-widths to the right, one lowered by
# 3 box-heights and the other raised by 2 box-heights.
boxpositions = [
    (0, 0),
    (3.5 * bwidth, -3 * bheight),
    (3.5 * bwidth, 2 * bheight),
]

# Labels for the four sub-boxes of each grid, listed in the same order as
# boxpositions; draw_2x2_grid reads them as numbers[2 * column + row].
boxnumbers = [
    [0, 1, 0, 0],
    [2, 3, 2, 0],
    [1, 2, 2, 1],
]
# Render each 2x2 grid at its corner position, filled with its labels.
for corner, labels in zip(boxpositions, boxnumbers):
    draw_2x2_grid(
        ax,
        corner,
        bwidth=bwidth,
        bheight=bheight,
        numbers=labels,
    )
# Connect the first grid to each of the other two. Every arrow starts at the
# middle of the first grid's right-hand edge and ends at the middle of the
# target grid's left-hand edge. Each row below gives: target corner, arrow
# colour, label under the arrow, label over the arrow.
for target, colour, below, above in (
    (boxpositions[1], "r", "$t = 2$", "$a_2(0)$"),
    (boxpositions[2], "k", "$t = 1$", "$a_1(1)$"),
):
    draw_arrow(
        ax,
        (boxpositions[0][0] + 2 * bwidth, boxpositions[0][1] + bheight),
        (target[0], target[1] + bheight),
        textbelow=below,
        textabove=above,
        arrowkwargs={
            "color": colour,
            "head_width": 0.05,
            "length_includes_head": True,
            "lw": 2,
        },
    )
# Give one data unit the same on-screen length along both axes so the grids
# keep their true proportions, then hide the axis lines, ticks and labels.
ax.axis("equal")
ax.set_axis_off()
# Use pyplot.show rather than Figure.show: Figure.show does not run a GUI
# event loop, so when this file is run as a plain script the window is
# shown only briefly (or not at all) before the interpreter exits.
# pyplot.show blocks until the window is closed and also works in
# interactive sessions.
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment