# Integrated gradients wrapper
# Source: GitHub gist rpryzant/a2324dd608c63f1637b1e36a1ffce46d by @rpryzant
# (last active June 20, 2022).
"""
USAGE
model = build_model()
attributor = Attributor(model, target_class=1, tokenizer=tokenizer)
...
# viz = interactive vizualization that you can dump into a file and look at in a web browser
# t2a = map of token to its attribution score
viz, t2a, attrs, y_prob, y_hat = attributor.attr_and_visualize(
batch['input_ids'], batch['labels'])
with open('vizualization.html', 'w') as f:
f.write('\n'.join(viz))
"""
from collections import defaultdict

import torch
from captum.attr import IntegratedGradients, LayerConductance, LayerIntegratedGradients
from captum.attr import configure_interpretable_embedding_layer, remove_interpretable_embedding_layer
from captum.attr import visualization as viz
class Attributor:
    """Integrated-gradients attribution wrapper for a token-sequence classifier.

    Runs captum's ``LayerIntegratedGradients`` through the model's embedding
    layer, attributing the softmax probability of ``target_class`` to the
    input tokens, and renders the result with captum's text visualizer.

    TODO generalize to multiclass visualization / batches larger than one.
    """

    # Fallback special-token ids: the values the original code hard-coded
    # ([CLS]=101, [SEP]=102, [PAD]=0 in the bert-base-uncased vocabulary).
    # Used only when the tokenizer does not expose the corresponding id.
    _DEFAULT_CLS_ID = 101
    _DEFAULT_SEP_ID = 102
    _DEFAULT_PAD_ID = 0

    def __init__(self, model, target_class, tokenizer, embedding_layer=None):
        """Build the forward function and the LayerIntegratedGradients object.

        Args:
            model: classifier called as ``model(input_ids)``; element 0 of its
                output is treated as a ``(batch, num_classes)`` logits tensor.
                NOTE(review): it presumably should be in eval mode so dropout
                does not perturb the attributions -- confirm at call sites.
            target_class: index of the class whose probability is attributed.
            tokenizer: provides ``convert_ids_to_tokens``; its
                ``cls_token_id`` / ``sep_token_id`` / ``pad_token_id`` (when
                present) define the integrated-gradients baseline.
            embedding_layer: module to attribute through. Defaults to
                ``model.distilbert.embeddings`` (the original DistilBERT-only
                choice).
        """
        self.model = model
        self.target_class = target_class
        self.tokenizer = tokenizer
        self.fwd_fn = self.build_forward_fn(target_class)
        if embedding_layer is None:
            embedding_layer = self.model.distilbert.embeddings
        self.lig = LayerIntegratedGradients(self.fwd_fn, embedding_layer)

    def _device(self):
        """Return the device of the model's parameters (CPU if it has none)."""
        try:
            return next(self.model.parameters()).device
        except (AttributeError, StopIteration):
            return torch.device("cpu")

    def _to_ids_tensor(self, input_ids):
        """Convert ``input_ids`` (tensor or nested list) to a tensor on the model's device."""
        return torch.as_tensor(input_ids).to(self._device())

    def _token_id(self, attr, default):
        """Read a special-token id from the tokenizer, falling back to ``default``."""
        value = getattr(self.tokenizer, attr, None)
        return default if value is None else value

    def _baseline(self, ids):
        """Return an integrated-gradients baseline shaped like ``ids``.

        Every position is replaced by the pad id except [CLS]/[SEP] tokens,
        which are kept so the baseline is still a well-formed sequence. The
        baseline is built per row (the original used only row 0 for the
        whole batch).
        """
        cls_id = self._token_id("cls_token_id", self._DEFAULT_CLS_ID)
        sep_id = self._token_id("sep_token_id", self._DEFAULT_SEP_ID)
        pad_id = self._token_id("pad_token_id", self._DEFAULT_PAD_ID)
        keep = (ids == cls_id) | (ids == sep_id)
        return torch.where(keep, ids, torch.full_like(ids, pad_id))

    def _predict(self, ids):
        """Return ``(probability of target_class, argmax class)`` for row 0 of ``ids``."""
        with torch.no_grad():
            probs = torch.softmax(self.model(ids)[0], dim=1)
        # The prediction is the argmax. The old "p > 0.5 -> class 1" rule was
        # only right when target_class == 1, because p is P(target_class).
        return probs[0, self.target_class].item(), int(probs[0].argmax().item())

    def attribute(self, input_ids, n_steps=25, internal_batch_size=5):
        """Run layer integrated gradients on ``input_ids``.

        Args:
            input_ids: 2-D token ids ``(batch, seq_len)``, as a tensor or
                nested list; moved to the model's device.
            n_steps: number of integration steps passed to captum.
            internal_batch_size: captum's internal batch size for the steps.

        Returns:
            ``(attribution_sum, delta)``: per-token attributions summed over
            the embedding dimension and L2-normalized (see ``summarize``),
            and captum's convergence delta.
        """
        ids = self._to_ids_tensor(input_ids)
        attribution, delta = self.lig.attribute(
            inputs=ids,
            baselines=self._baseline(ids),
            n_steps=n_steps,
            internal_batch_size=internal_batch_size,
            return_convergence_delta=True,
        )
        return self.summarize(attribution), delta

    def attr_and_visualize(self, input_ids, label):
        """Attribute one example and render it with captum's text visualizer.

        Only the first row of ``input_ids`` is visualized and ``label`` must
        hold a single value, so a batch of size 1 is expected.

        Args:
            input_ids: ``(1, seq_len)`` token ids (tensor or nested list).
            label: gold label for the example (tensor or Python scalar).

        Returns:
            ``(html, tok2attr, attr_sum, y_prob, pred_class)``:
            ``html`` is ``.data`` of the object returned by
            ``viz.visualize_text`` (presumably the HTML markup);
            ``tok2attr`` is a ``defaultdict(list)`` mapping each token string
            to the attribution scores it received; ``attr_sum`` is the
            per-token attribution numpy array; ``y_prob`` is the probability
            of ``target_class``; ``pred_class`` is the argmax class index.
        """
        ids = self._to_ids_tensor(input_ids)
        attr_sum, delta = self.attribute(ids)
        y_prob, pred_class = self._predict(ids)

        attr_sum = attr_sum.detach().cpu().numpy()
        label = torch.as_tensor(label).item()
        tokens = self.tokenizer.convert_ids_to_tokens(ids[0].tolist())

        record = viz.VisualizationDataRecord(
            attr_sum,
            y_prob,
            pred_class,
            label,
            self.target_class,
            attr_sum.sum(),
            tokens,
            delta,
        )
        tok2attr = defaultdict(list)
        for tok, attr in zip(tokens, attr_sum):
            tok2attr[tok].append(attr)
        html = viz.visualize_text([record])
        return html.data, tok2attr, attr_sum, y_prob, pred_class

    def build_forward_fn(self, label_dim):
        """Return ``f(inputs) -> softmax(model(inputs)[0], dim=1)[:, label_dim]``.

        This is the scalar-per-example function captum differentiates.
        ``self.model`` is read at call time, not captured at build time.
        """
        def custom_forward(inputs):
            preds = self.model(inputs)[0]
            return torch.softmax(preds, dim=1)[:, label_dim]
        return custom_forward

    def summarize(self, attributions):
        """Sum attributions over the last (embedding) dim and L2-normalize.

        ``squeeze(0)`` drops the batch dimension only when it has size 1.
        If every sum is zero, the unnormalized zeros are returned instead of
        dividing by a zero norm, which would produce NaN.
        """
        attributions = attributions.sum(dim=-1).squeeze(0)
        norm = torch.norm(attributions)
        if norm == 0:
            return attributions
        return attributions / norm
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment