Last active
March 17, 2019 02:11
-
-
Save strubell/539a9ccba0de6a925b108c3c728c8316 to your computer and use it in GitHub Desktop.
plot attention
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import division | |
from __future__ import print_function | |
import numpy as np | |
import matplotlib.pyplot as plt | |
import os | |
import re | |
import string | |
def sorted_alphanum(l):
    """Return *l* sorted so that embedded numbers compare numerically.

    e.g. ["x10", "x2"] sorts as ["x2", "x10"] instead of lexicographically.
    """
    def _alphanum_key(s):
        # Split into alternating non-digit / digit runs; digit runs become ints
        # so "10" sorts after "2".
        return [int(chunk) if chunk.isdigit() else chunk
                for chunk in re.split('([0-9]+)', s)]
    return sorted(l, key=_alphanum_key)
# Load the dumped attention tensors and the token-level annotation file.
# Each annotation line is: "<id>\t<word>\t<e1-flag>\t<e2-flag>..." where the
# id looks like "batch<B>_example<E>_token<T>".
# NOTE: the original used Python 2-only idioms (`string.strip`, indexing the
# result of `map`) and leaked the file handle; the comprehensions below are
# equivalent and run under both Python 2 and 3.
data = np.load("pat_attn.npz")
with open("attention_weights.txt", 'r') as attn_file:
    lines = [line.split('\t') for line in attn_file.readlines()]
save_dir = "attn_pat"

docs = []          # docs[i] = column lists for doc i: [words, e1 flags, e2 flags, ...]
current_doc = []   # rows (one per token) accumulated for the current document
current_example = "0"
for line in lines:
    # Splitting on letter/underscore runs yields ["", batch, example, token].
    _, this_batch, this_example, this_tok = re.split("[a-z_]*", line[0])
    if this_example != current_example:
        # Example boundary: transpose accumulated token rows into columns.
        # NOTE(review): the boundary line itself is never appended, so the
        # first token of every example is dropped -- confirm this is intended.
        docs.append([list(col) for col in zip(*current_doc)])
        current_doc = []
        current_example = this_example
    else:
        if line[1] != "<PAD>":
            # Keep all annotation columns for this token, whitespace-stripped.
            current_doc.append([field.strip() for field in line[1:]])
# Flush the final document.
docs.append([list(col) for col in zip(*current_doc)])
# Index of the deepest attention layer present in the dump.
max_layer = 1

# Tick-label colors keyed by combined entity code (2*e1 + e2).
colors_map = {0: 'black', 1: 'red', 2: 'blue'}

batch_size = data[data.files[0]].shape[0]

# Optionally restrict plotting to a random subsample of documents.
sample = False
num_samples = 20
if sample:
    total_num = len(data.files) * batch_size
    samples = np.random.choice(total_num, num_samples, replace=False)
    sample_ids = ' '.join(map(str, sorted(samples)))
    print("Plotting samples: %s" % (sample_ids))

# Running document offset across batches; one shared figure reused per plot.
batch_sum = 0
fig, ax = plt.subplots()
# Make sure the output directory exists before the first savefig call.
if not os.path.isdir(save_dir):
    os.makedirs(save_dir)

# For each saved array, named "batch<B>_layer<L>", in numeric order.
for arr_name in sorted_alphanum(data.files):
    print("Processing %s" % arr_name)
    split_name = re.split("[a-z_]*", arr_name)
    batch = int(split_name[1])
    layer = int(split_name[2])
    # For each element in the batch (one layer of one document).
    for b_i, arrays in enumerate(data[arr_name]):
        doc_idx = batch_sum + b_i
        if not sample or doc_idx in samples:
            if sample:
                print("Taking batch: %d, doc: %d, layer: %d" % (batch, doc_idx, layer))
            doc = docs[doc_idx]
            words = doc[0]
            # Entity indicator columns, combined into a code in {0,1,2} that
            # selects the tick-label color.
            e1 = np.array([int(t) for t in doc[1]])
            e2 = np.array([int(t) for t in doc[2]])
            doc_len = len(words)
            # Materialize as a list: it is consumed TWICE below (x and y tick
            # labels). A bare `map` iterator would be exhausted after the
            # first zip on Python 3, leaving the y-axis labels uncolored.
            tick_colors = [colors_map.get(c) for c in (2 * e1 + e2)[:doc_len]]
            # For each attention head.
            for head, arr in enumerate(arrays):
                name = "doc%d_layer%d_head%d" % (doc_idx, layer, head)
                ax.set_title(name, fontsize=8)
                # axis 1 of arr sums to 1.
                # 'none' (the string) truly disables interpolation; passing
                # None would fall back to the rcParam default instead.
                ax.imshow(arr[:doc_len, :doc_len], cmap=plt.cm.viridis,
                          interpolation='none')
                ax.set_xticks(range(doc_len))
                ax.set_yticks(range(doc_len))
                ax.set_xticklabels(words, rotation=75, fontsize=2)
                ax.set_yticklabels(words, fontsize=2)
                for t, c in zip(ax.get_xticklabels(), tick_colors):
                    t.set_color(c)
                for t, c in zip(ax.get_yticklabels(), tick_colors):
                    t.set_color(c)
                fig.tight_layout()
                fig.savefig(os.path.join(save_dir, name + ".pdf"))
                ax.clear()
    # Every layer of a batch covers the same documents; advance the running
    # document offset only once per batch, after its deepest layer.
    if layer == max_layer:
        batch_sum += batch_size
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment