Gist by @muellerzr, created April 12, 2020 19:59
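These helpers compute "intrinsic attention" for a fastai2 text classifier: the text is run through the model, the score of the predicted (or requested) class is back-propagated to the input embeddings, and the magnitude of each token's embedding gradient, normalized to [0, 1], is rendered as a color-coded HTML heat map.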
from fastai2.text.all import *
import html
from IPython.display import display, HTML
import matplotlib.cm as cm
def _eval_dropouts(mod):
    "Recursively put every `Dropout` and `BatchNorm` submodule into eval mode."
    module_name = mod.__class__.__name__
    if 'Dropout' in module_name or 'BatchNorm' in module_name: mod.training = False
    for module in mod.children(): _eval_dropouts(module)
def intrinsic_attention(learn, text, class_id=None):
    "Calculate the intrinsic attention of the input w.r.t. an output `class_id`, or the classification given by the model if `None`."
    learn.model.train()
    _eval_dropouts(learn.model)
    learn.model.zero_grad()
    learn.model.reset()
    dl = learn.dls.test_dl([text])
    ids = dl.one_batch()[0]
    # Embed the tokens separately so we can take gradients w.r.t. the embeddings
    emb = learn.model[0].module.encoder(ids).detach().requires_grad_(True)
    lstm = learn.model[0].module(emb, True)
    learn.model.eval()
    cl = learn.model[1]((lstm, torch.zeros_like(ids).bool(),))[0].softmax(dim=-1)
    if class_id is None: class_id = cl.argmax()
    cl[0][class_id].backward()
    # Per-token score: L1 norm of the embedding gradient, scaled to [0, 1]
    attn = emb.grad.squeeze().abs().sum(dim=-1)
    attn /= attn.max()
    tok, _ = learn.dls.decode_batch((*tuplify(ids), *tuplify(cl)))[0]
    return tok, attn
def value2rgba(x, cmap=cm.RdYlGn, alpha_mult=1.0):
    "Convert a value `x` from 0 to 1 (inclusive) to an RGBA tuple according to `cmap` times transparency `alpha_mult`."
    c = cmap(x)
    rgb = (np.array(c[:-1]) * 255).astype(int)
    a = c[-1] * alpha_mult
    return tuple(rgb.tolist() + [a])
def piece_attn_html(pieces, attns, sep=' ', **kwargs):
    html_code,spans = ['<span style="font-family: monospace;">'], []
    for p, a in zip(pieces, attns):
        p = html.escape(p)
        c = str(value2rgba(a, alpha_mult=0.5, **kwargs))
        spans.append(f'<span title="{a:.3f}" style="background-color: rgba{c};">{p}</span>')
    html_code.append(sep.join(spans))
    html_code.append('</span>')
    return ''.join(html_code)
def show_piece_attn(*args, **kwargs):
    from IPython.display import display, HTML
    display(HTML(piece_attn_html(*args, **kwargs)))
def html_intrinsic_attention(learn, text:str, class_id:int=None, **kwargs)->str:
    text, attn = intrinsic_attention(learn, text, class_id)
    return piece_attn_html(text.split(), to_np(attn), **kwargs)
def show_intrinsic_attention(learn, text:str, class_id:int=None, **kwargs)->None:
    text, attn = intrinsic_attention(learn, text, class_id)
    show_piece_attn(text.split(), to_np(attn), **kwargs)
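A minimal usage sketch, not part of the original gist: it assumes a fastai2 text classifier trained with the standard high-level API, here on the IMDB sample dataset. The dataset, column names, and sample review are illustrative placeholders.

from fastai2.text.all import *

# Train a small classifier on the IMDB sample (placeholder setup for illustration)
path = untar_data(URLs.IMDB_SAMPLE)
dls = TextDataLoaders.from_csv(path, 'texts.csv', text_col='text', label_col='label')
learn = text_classifier_learner(dls, AWD_LSTM, metrics=accuracy)
learn.fine_tune(1)

# Render a per-token heat map of the gradient-based attention in a notebook
show_intrinsic_attention(learn, "This movie was absolutely wonderful!")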