from fastai2.text.all import *
import html
from IPython.display import display, HTML
import matplotlib.cm as cm
def _eval_dropouts(mod):
    # Recursively put every Dropout/BatchNorm layer in eval mode while leaving
    # the rest of the model in training mode, so gradients still flow
    module_name = mod.__class__.__name__
    if 'Dropout' in module_name or 'BatchNorm' in module_name: mod.training = False
    for module in mod.children(): _eval_dropouts(module)
def intrinsic_attention(learn, text, class_id=None):
    "Calculate the intrinsic attention of the input w.r.t. an output `class_id`, or the classification given by the model if `None`."
    learn.model.train()
    _eval_dropouts(learn.model)
    learn.model.zero_grad()
    learn.model.reset()
    dl = learn.dls.test_dl([text])
    batch = dl.one_batch()[0]
    # Embed the tokens, then detach and re-enable gradients so we can
    # differentiate the class score w.r.t. the embeddings
    emb = learn.model[0].module.encoder(batch).detach().requires_grad_(True)
    lstm = learn.model[0].module(emb, True)  # True: input is already embedded
    learn.model.eval()
    cl = learn.model[1]((lstm, torch.zeros_like(batch).bool(),))[0].softmax(dim=-1)
    if class_id is None: class_id = cl.argmax()
    cl[0][class_id].backward()
    # Attention score per token: L1 norm of the embedding gradient, scaled to [0, 1]
    attn = emb.grad.squeeze().abs().sum(dim=-1)
    attn /= attn.max()
    tok, _ = learn.dls.decode_batch((*tuplify(batch), *tuplify(cl)))[0]
    return tok, attn
def value2rgba(x, cmap=cm.RdYlGn, alpha_mult=1.0):
    "Convert a value `x` in [0,1] to an RGBA tuple according to `cmap`, with alpha scaled by `alpha_mult`."
    c = cmap(x)
    rgb = (np.array(c[:-1]) * 255).astype(int)
    a = c[-1] * alpha_mult
    return tuple(rgb.tolist() + [a])
def piece_attn_html(pieces, attns, sep=' ', **kwargs):
    # Build one colored <span> per token, with the raw attention score as a tooltip
    html_code,spans = ['<span style="font-family: monospace;">'], []
    for p, a in zip(pieces, attns):
        p = html.escape(p)
        c = str(value2rgba(a, alpha_mult=0.5, **kwargs))
        spans.append(f'<span title="{a:.3f}" style="background-color: rgba{c};">{p}</span>')
    html_code.append(sep.join(spans))
    html_code.append('</span>')
    return ''.join(html_code)
def show_piece_attn(*args, **kwargs):
    display(HTML(piece_attn_html(*args, **kwargs)))
def html_intrinsic_attention(learn, text:str, class_id:int=None, **kwargs)->str:
    text, attn = intrinsic_attention(learn, text, class_id)
    return piece_attn_html(text.split(), to_np(attn), **kwargs)

def show_intrinsic_attention(learn, text:str, class_id:int=None, **kwargs)->None:
    text, attn = intrinsic_attention(learn, text, class_id)
    show_piece_attn(text.split(), to_np(attn), **kwargs)
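
Minimal usage sketch (an assumption, not part of the gist): the lines below train a quick sentiment classifier on the fastai IMDB sample and then highlight each token by its attention score. The data, learner setup, and example sentence are illustrative; any fastai2 `text_classifier_learner` built on `AWD_LSTM` should work.

# Illustrative example, assuming a standard fastai2 text-classification setup
path = untar_data(URLs.IMDB_SAMPLE)
dls = TextDataLoaders.from_csv(path, csv_fname='texts.csv', text_col='text',
                               label_col='label', valid_col='is_valid')
learn = text_classifier_learner(dls, AWD_LSTM, drop_mult=0.5, metrics=accuracy)
learn.fine_tune(1)
# Render the tokens colored by intrinsic attention (in a notebook)
show_intrinsic_attention(learn, "This movie was absolutely wonderful!")
# Or get the raw HTML string instead of displaying it
html_out = html_intrinsic_attention(learn, "This movie was absolutely wonderful!")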