Skip to content

Instantly share code, notes, and snippets.

@kingjr
Last active December 6, 2019 15:25
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save kingjr/552310cce64dd0860faaf04561799871 to your computer and use it in GitHub Desktop.
Save kingjr/552310cce64dd0860faaf04561799871 to your computer and use it in GitHub Desktop.
architecture_plot
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
# Grid dimensions shared by every panel. Arrows start from row 0 and unit
# stacks are drawn on rows 1 .. n_layers - 1; there is one column per step.
n_layers = 4
n_steps = 4

# One row of three panels (one per architecture) with shared x and y axes.
fig, axes = plt.subplots(nrows=1, ncols=3, sharex=True, sharey=True,
                         figsize=[10, 3])

# Keyword arguments applied to every ax.arrow() call in this script.
arrow = {'width': .01, 'head_width': .1, 'color': 'C0'}
def plot_architecture(ax, n_layers=n_layers):
    """Draw the backdrop common to all panels on ``ax``.

    Places a small diagonal stack of white, black-edged markers at every
    (step, layer) position for layers ``1 .. n_layers - 1``, writes the four
    example words below the grid, and strips the axes decoration.

    Parameters
    ----------
    ax : matplotlib Axes
        Panel to draw on.
    n_layers : int
        Number of rows; markers are drawn on rows 1 and above. Defaults to
        the module-level ``n_layers`` as it was when this function was
        defined.
    """
    # Four offsets along the diagonal; each one yields one marker of a stack.
    offsets = np.linspace(-.15, .15, 4)
    for column in range(n_steps):
        for row in range(1, n_layers):
            for shift in offsets:
                # zorder decreases as the shift grows, so the lower-left
                # marker of each stack is drawn on top.
                ax.scatter(column + shift, row + shift, color='w',
                           edgecolor='k', zorder=1000 - shift, s=100)

    # Input words, centred under columns 0..3 just below row 0.
    sentence = 'brains are really great'
    for position, token in enumerate(sentence.split()):
        ax.text(position, -.5, token, horizontalalignment='center', color='k')

    ax.set_aspect('equal')

    # Hide the frame and the ticks so only the diagram remains.
    for side in ('top', 'right', 'left', 'bottom'):
        ax.spines[side].set_visible(False)
    ax.set_xticks([])
    ax.set_yticks([])

    # Half a cell of margin around the grid, plus room below for the words.
    ax.set_xlim(-.5, n_steps - .5)
    ax.set_ylim(-1, n_layers - .5)
def plot_word_embedding(ax):
    """Draw the 'Word Embedding' panel on ``ax``.

    Uses a two-row backdrop (a single row of unit stacks) and one upward
    arrow per step, starting at row 0.
    """
    plot_architecture(ax, 2)

    # A single upward arrow per column, from row 0 toward the unit row.
    base_row = 0
    for column in range(n_steps):
        ax.arrow(column, base_row, 0, .6, **arrow)

    ax.set_title('Word Embedding')
def plot_lstm(ax):
    """Draw the 'Causal LSTM' panel on ``ax``.

    On top of the default backdrop, adds upward (feedforward) arrows within
    each step and rightward (recurrent) arrows between consecutive steps on
    every unit row.
    """
    plot_architecture(ax)

    # Feedforward: each row below the top one points straight up.
    upward_starts = [(column, row)
                     for column in range(n_steps)
                     for row in range(n_layers - 1)]
    for column, row in upward_starts:
        ax.arrow(column, row, 0, .6, **arrow)

    # Recurrence: each unit row points to the next step on the same row.
    rightward_starts = [(column, row)
                        for column in range(n_steps - 1)
                        for row in range(1, n_layers)]
    for column, row in rightward_starts:
        ax.arrow(column, row, .6, 0, **arrow)

    ax.set_title('Causal LSTM')
def plot_transformer(ax):
    """Draw the 'Causal Transformer' panel on ``ax``.

    On top of the default backdrop, adds upward (feedforward) arrows within
    each step and diagonal (attention) arrows from every step to each later
    step, one row up.
    """
    plot_architecture(ax)

    # Feedforward: each row below the top one points straight up.
    for column in range(n_steps):
        for row in range(n_layers - 1):
            ax.arrow(column, row, 0, .6, **arrow)

    # Attention: from each step to every later step; `hop` is how many
    # columns ahead the target lies, so only in-range targets are generated.
    for source in range(n_steps - 1):
        for row in range(n_layers - 1):
            for hop in range(1, n_steps - source):
                # The shaft stops .2 short of the target column.
                ax.arrow(source, row, hop - .2, .6, **arrow)

    ax.set_title('Causal Transformer')
# Render one panel per architecture, left to right, then export the figure.
panel_painters = (plot_word_embedding, plot_lstm, plot_transformer)
for paint_panel, panel_ax in zip(panel_painters, axes):
    paint_panel(panel_ax)

fig.tight_layout()
fig.savefig('models.svg')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment