Skip to content

Instantly share code, notes, and snippets.

@tnq177
Created January 7, 2021 20:43
Show Gist options
  • Save tnq177/c493ac7cdb926b074defff2729d0ffc8 to your computer and use it in GitHub Desktop.
Save tnq177/c493ac7cdb926b074defff2729d0ffc8 to your computer and use it in GitHub Desktop.
plotting attention weights with bokeh #attention #bokeh
import numpy as np
from bokeh.io import export_png
from bokeh.palettes import Blues256
from bokeh.plotting import figure, output_file, save
def plot_att(src, tgt, weights, out_filepath):
    """
    Plot attention using Bokeh and save it as HTML and PNG.

    Output is a 2D matrix with x-axis=src (labels on top) and y-axis=tgt
    (first tgt word on the top row). Each cell = attention weight between
    the corresponding src and tgt words, shaded from white (low) to blue
    (high). Hovering a cell shows the word pair and the raw weight.

    Parameters:
        src (list): list of src words
        tgt (list): list of tgt words
        weights (array-like): 2D array of shape [len(tgt), len(src)].
            Colors assume weights in [0, 1]; values outside that range
            are clamped to the lightest/darkest color.
        out_filepath (str): output path without extension.

    Returns:
        None. Writes out_filepath + '.html' (interactive, great for
        exploring with hover tooltips) and out_filepath + '.png'.

    Raises:
        ValueError: if weights does not have shape (len(tgt), len(src)).

    Note:
        PNG export goes through bokeh.io.export_png, which (per the Bokeh
        docs) needs selenium and a browser webdriver installed.
    """
    src = list(src)
    tgt = list(tgt)
    weights = np.asarray(weights, dtype=float)
    if weights.shape != (len(tgt), len(src)):
        raise ValueError(
            'weights must have shape (len(tgt), len(src)) = {}, got {}'.format(
                (len(tgt), len(src)), weights.shape))

    # small->large value = white->blue
    colormap = Blues256[::-1]
    n_colors = len(colormap)

    # Bokeh's y-axis grows upward, so reverse tgt to display tgt words from
    # top to bottom. np.flip returns a new array instead of flipping in
    # place, so it must be reassigned to keep the weight rows aligned with
    # the reversed tgt labels.
    tgt = tgt[::-1]
    weights = np.flip(weights, axis=0)

    xs = []
    ys = []
    colors = []
    att_weights = []
    xtoks = []
    ytoks = []
    for x, src_tok in enumerate(src):
        for y, tgt_tok in enumerate(tgt):
            weight = float(weights[y, x])
            # Map [0, 1] onto colormap indices [0, n_colors - 1]. Clamp rather
            # than wrap with `% n_colors`: wrapping turns a weight of exactly
            # 1.0 (index n_colors) into index 0, i.e. white.
            color_idx = min(max(int(weight * n_colors), 0), n_colors - 1)
            xs.append(x)
            ys.append(y)
            colors.append(colormap[color_idx])
            att_weights.append(weight)
            xtoks.append(src_tok)
            ytoks.append(tgt_tok)

    # Column names are referenced by the rect glyph and the hover tooltips.
    data = dict(
        xs=xs,
        ys=ys,
        colors=colors,
        att_weights=att_weights,
        xtoks=xtoks,
        ytoks=ytoks,
    )

    # width/height passed here work across Bokeh versions; the old
    # plot_width/plot_height attributes were removed in Bokeh 3.
    p = figure(
        width=800,
        height=800,
        x_axis_location='above',
        tools='hover,save',
        tooltips=[('pair', '@ytoks, @xtoks'), ('weight', '@att_weights')])

    # Place one tick per cell and override the numeric labels with the words.
    p.xaxis.ticker = list(range(len(src)))
    p.xaxis.major_label_overrides = dict(enumerate(src))
    p.yaxis.ticker = list(range(len(tgt)))
    p.yaxis.major_label_overrides = dict(enumerate(tgt))

    p.grid.grid_line_color = None
    p.axis.axis_line_color = None
    p.axis.major_tick_line_color = None
    p.axis.major_label_standoff = 0
    p.xaxis.major_label_orientation = np.pi / 3

    p.rect('xs', 'ys', 1.0, 1.0, color='colors', source=data,
           line_color=None, hover_line_color='black', hover_color='colors')

    output_file(out_filepath + '.html')
    save(p)
    # `filename` is keyword-only in Bokeh >= 2.0; passing it positionally
    # raises TypeError there.
    export_png(p, filename=out_filepath + '.png')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment