@strubell
Last active March 17, 2019 02:11
Plot attention: render each layer's per-head self-attention matrices as heatmaps, with entity tokens highlighted in the axis tick labels.
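
Before the script itself, here is a minimal, hypothetical sketch of the two input files it expects. The .npz key pattern ("batch{i}_layer{j}") and the tab-separated token format are inferred from the parsing code below, so your actual export step may differ:

import numpy as np

# Hypothetical dimensions: one batch of 2 docs, 2 layers, 3 heads, max length 4.
batch_size, num_layers, num_heads, max_len = 2, 2, 3, 4

arrays = {}
for layer in range(num_layers):
    attn = np.random.rand(batch_size, num_heads, max_len, max_len)
    attn /= attn.sum(axis=-1, keepdims=True)  # rows sum to 1, like softmax attention
    arrays["batch0_layer%d" % layer] = attn
np.savez("pat_attn.npz", **arrays)

# One tab-separated row per token: an id like "batch0_example0_token0", the
# token text, then two 0/1 entity-indicator columns.
toy_doc = [("John", 1, 0), ("met", 0, 0), ("Mary", 0, 1), (".", 0, 0)]
with open("attention_weights.txt", "w") as f:
    for ex in range(batch_size):
        for tok, (word, f1, f2) in enumerate(toy_doc):
            f.write("batch0_example%d_token%d\t%s\t%d\t%d\n" % (ex, tok, word, f1, f2))

The script below consumes these two files.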
from __future__ import division
from __future__ import print_function

import os
import re

import matplotlib.pyplot as plt
import numpy as np
def sorted_alphanum(l):
    """Sort strings so that embedded integers compare numerically ("natural sort")."""
    convert = lambda text: int(text) if text.isdigit() else text
    alphanum_key = lambda key: [convert(c) for c in re.split('([0-9]+)', key)]
    return sorted(l, key=alphanum_key)
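
# Natural sort matters for the .npz keys: with hypothetical keys like
# "batch10_layer0" and "batch2_layer0", plain sorted() puts "batch10_..."
# first (lexicographically '1' < '2'), while sorted_alphanum() correctly
# orders all "batch2_..." entries before "batch10_...".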
data = np.load("pat_attn.npz")
lines = map(lambda x: x.split('\t'), open("attention_weights.txt", 'r').readlines())
save_dir = "attn_pat"
docs = []
current_doc = []
current_batch = current_example = current_tok = "0"
for line in lines:
_, this_batch, this_example, this_tok = re.split("[a-z_]*", line[0])
if this_example != current_example:
docs.append(map(list, zip(*current_doc)))
current_doc = []
current_example = this_example
else:
if line[1] != "<PAD>":
current_doc.append(map(string.strip, line[1:]))
docs.append(map(list, zip(*current_doc)))
# index of the deepest layer (the npz holds layers 0..max_layer per batch)
max_layer = 1
# colors for highlighting entity tokens in the tick labels
colors_map = {0: 'black', 1: 'red', 2: 'blue'}
batch_size = data[data.files[0]].shape[0]

# optionally plot only a random subsample of documents
sample = False
num_samples = 20
if sample:
    # each batch appears once per layer in data.files, so divide that out
    total_num = batch_size * len(data.files) // (max_layer + 1)
    samples = np.random.choice(total_num, num_samples, replace=False)
    print("Plotting samples: %s" % ' '.join(map(str, sorted(samples))))
# for each batch + layer array in the npz
batch_sum = 0
fig, ax = plt.subplots()
for arr_name in sorted_alphanum(data.files):
    print("Processing %s" % arr_name)
    split_name = re.split("[a-z_]+", arr_name)
    batch = int(split_name[1])
    layer = int(split_name[2])
    # for each document in the batch (at this layer)
    for b_i, arrays in enumerate(data[arr_name]):
        doc_idx = batch_sum + b_i
        if not sample or doc_idx in samples:
            if sample:
                print("Taking batch: %d, doc: %d, layer: %d" % (batch, doc_idx, layer))
            doc = docs[doc_idx]
            words = doc[0]
            e1 = np.array([int(x) for x in doc[1]])
            e2 = np.array([int(x) for x in doc[2]])
            doc_len = len(words)
            # 2*e1 + e2 encodes the two entity columns as 0-3; fall back to
            # black for 3 (both flags set), which colors_map doesn't cover
            tick_colors = [colors_map.get(v, 'black') for v in (2 * e1 + e2)[:doc_len]]
            # for each attention head
            for head, arr in enumerate(arrays):
                name = "doc%d_layer%d_head%d" % (doc_idx, layer, head)
                ax.set_title(name, fontsize=8)
                # each row of arr sums to 1 (softmax over keys); the string
                # 'none' (not None) disables interpolation so cells stay crisp
                ax.imshow(arr[:doc_len, :doc_len], cmap=plt.cm.viridis,
                          interpolation='none')
                ax.set_xticks(range(doc_len))
                ax.set_yticks(range(doc_len))
                ax.set_xticklabels(words, rotation=75, fontsize=2)
                ax.set_yticklabels(words, fontsize=2)
                for t, c in zip(ax.get_xticklabels(), tick_colors):
                    t.set_color(c)
                for t, c in zip(ax.get_yticklabels(), tick_colors):
                    t.set_color(c)
                fig.tight_layout()
                fig.savefig(os.path.join(save_dir, name + ".pdf"))
                ax.clear()
    # docs are indexed across batches; advance only after the last layer
    if layer == max_layer:
        batch_sum += batch_size
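
With inputs shaped as in the sketch above, the script writes one heatmap per (document, layer, head) under attn_pat/, with names like doc0_layer1_head2.pdf; the x and y tick labels are the document's tokens, colored red or blue where the corresponding entity column is set.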