Last active March 31, 2020
Trying captum interpretation on pretrained sentiment classifier
"cells": [
# Interpretation of BertForSequenceClassification in captum
In this notebook we use Captum to interpret a BERT sentiment classifier fine-tuned on the IMDb dataset.
# Install dependencies
# Use the %pip magic (rather than !pip) so the packages are installed into the running kernel's environment.
%pip install transformers
%pip install captum
"from transformers import BertTokenizer, BertForSequenceClassification, BertConfig\n",
"from captum.attr import visualization as viz\n",
"from captum.attr import IntegratedGradients, LayerConductance, LayerIntegratedGradients\n",
"from captum.attr import configure_interpretable_embedding_layer, remove_interpretable_embedding_layer\n",
"import torch\n",
"import matplotlib.pyplot as plt"
# Run on the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
"# Get model and config files from\n",
"!wget -P ./model\n",
"!wget -P ./model\n",
"!wget -P ./model\n",
"!wget -P ./model\n",
"!wget -P ./model\n",
"!wget -P ./model"
"# load model\n",
"model = BertForSequenceClassification.from_pretrained('./model')\n",
"# load tokenizer\n",
"tokenizer = BertTokenizer.from_pretrained('./model')"
"def predict(inputs):\n",
" return model(inputs)[0]"
"ref_token_id = tokenizer.pad_token_id # A token used for generating token reference\n",
"sep_token_id = tokenizer.sep_token_id # A token used as a separator between question and text and it is also added to the end of the text.\n",
"cls_token_id = tokenizer.cls_token_id # A token used for prepending to the concatenated question-text word sequence"
"def construct_input_ref_pair(text, ref_token_id, sep_token_id, cls_token_id):\n",
" text_ids = tokenizer.encode(text, add_special_tokens=False)\n",
" # construct input token ids\n",
" input_ids = [cls_token_id] + text_ids + [sep_token_id]\n",
" # construct reference token ids \n",
" ref_input_ids = [cls_token_id] + [ref_token_id] * len(text_ids) + [sep_token_id]\n",
" return torch.tensor([input_ids], device=device), torch.tensor([ref_input_ids], device=device), len(text_ids)\n",
"def construct_input_ref_token_type_pair(input_ids, sep_ind=0):\n",
" seq_len = input_ids.size(1)\n",
" token_type_ids = torch.tensor([[0 if i <= sep_ind else 1 for i in range(seq_len)]], device=device)\n",
" ref_token_type_ids = torch.zeros_like(token_type_ids, device=device)# * -1\n",
" return token_type_ids, ref_token_type_ids\n",
"def construct_input_ref_pos_id_pair(input_ids):\n",
" seq_length = input_ids.size(1)\n",
" position_ids = torch.arange(seq_length, dtype=torch.long, device=device)\n",
" # we could potentially also use random permutation with `torch.randperm(seq_length, device=device)`\n",
" ref_position_ids = torch.zeros(seq_length, dtype=torch.long, device=device)\n",
" position_ids = position_ids.unsqueeze(0).expand_as(input_ids)\n",
" ref_position_ids = ref_position_ids.unsqueeze(0).expand_as(input_ids)\n",
" return position_ids, ref_position_ids\n",
" \n",
"def construct_attention_mask(input_ids):\n",
" return torch.ones_like(input_ids)"
"def custom_forward(inputs):\n",
" preds = predict(inputs)\n",
" return torch.softmax(preds, dim = 1)[0][0].unsqueeze(-1)"
"execution_count": 0,
# Attribute the output of `custom_forward` to the output of BERT's embedding layer.
lig = LayerIntegratedGradients(forward_func=custom_forward, layer=model.bert.embeddings)
"# One can test a couple of examples and check that the sentiment classifier is behaving\n",
"text = \"The movie was one of those amazing movies you can't forget.\"\n",
"#text = \"The movie was one of those crappy movies you can't forget.\""
"execution_count": 0,
"input_ids, ref_input_ids, sep_id = construct_input_ref_pair(text, ref_token_id, sep_token_id, cls_token_id)\n",
"token_type_ids, ref_token_type_ids = construct_input_ref_token_type_pair(input_ids, sep_id)\n",
"position_ids, ref_position_ids = construct_input_ref_pos_id_pair(input_ids)\n",
"attention_mask = construct_attention_mask(input_ids)\n",
"indices = input_ids[0].detach().tolist()\n",
"all_tokens = tokenizer.convert_ids_to_tokens(indices)"
"# Check predict output\n",
tensor([[-3.3635, 4.0115]], device='cuda:0', grad_fn=<AddmmBackward>)
"# Check output of custom_forward\n",
tensor([0.0006], device='cuda:0', grad_fn=<UnsqueezeBackward0>)
"attributions, delta = lig.attribute(inputs=input_ids,\n",
" baselines=ref_input_ids,\n",
" return_convergence_delta=True)"
"score = predict(input_ids)\n",
"print('Sentence: ', text)\n",
"print('Sentiment: ' + str(torch.argmax(score[0]).cpu().numpy()) + \\\n",
" ', Probability positive: ' + str(torch.softmax(score, dim = 1)[0][1].cpu().detach().numpy()))"
"Sentence: The movie was one of those amazing movies you can't forget.\n",
"def summarize_attributions(attributions):\n",
" attributions = attributions.sum(dim=-1).squeeze(0)\n",
" attributions = attributions / torch.norm(attributions)\n",
" return attributions"
# Collapse the attributions to one normalized score per sequence position
# (sum over the last dimension, then divide by the norm).
attributions_sum = summarize_attributions(attributions)
"# storing couple samples in an array for visualization purposes\n",
"score_vis = viz.VisualizationDataRecord(attributions_sum,\n",
" torch.softmax(score, dim = 1)[0][0],\n",
" torch.argmax(torch.softmax(score, dim = 1)[0]),\n",
" 0,\n",
" text,\n",
" attributions_sum.sum(), \n",
" all_tokens,\n",
" delta)\n"
# Print a bold heading (ANSI escape codes), then render the attribution record.
print('\033[1m', 'Visualization For Score', '\033[0m')
viz.visualize_text([score_vis])
<table width: 100%><tr><th>True Label</th><th>Predicted Label</th><th>Attribution Label</th><th>Attribution Score</th><th>Word Importance</th><tr><td><text style="padding-right:2em"><b>0</b></text></td><td><text style="padding-right:2em"><b>1 (0.00)</b></text></td><td><text style="padding-right:2em"><b>The movie was one of those amazing movies you can't forget.</b></text></td><td><text style="padding-right:2em"><b>-0.72</b></text></td><td><mark style="background-color: hsl(0, 75%, 100%); opacity:1.0; line-height:1.75"><font color="black"> [CLS] </font></mark><mark style="background-color: hsl(120, 75%, 92%); opacity:1.0; line-height:1.75"><font color="black"> The </font></mark><mark style="background-color: hsl(0, 75%, 100%); opacity:1.0; line-height:1.75"><font color="black"> movie </font></mark><mark style="background-color: hsl(0, 75%, 96%); opacity:1.0; line-height:1.75"><font color="black"> was </font></mark><mark style="background-color: hsl(120, 75%, 71%); opacity:1.0; line-height:1.75"><font color="black"> one </font></mark><mark style="background-color: hsl(0, 75%, 95%); opacity:1.0; line-height:1.75"><font color="black"> of </font></mark><mark style="background-color: hsl(0, 75%, 97%); opacity:1.0; line-height:1.75"><font color="black"> those </font></mark><mark style="background-color: hsl(0, 75%, 99%); opacity:1.0; line-height:1.75"><font color="black"> amazing </font></mark><mark style="background-color: hsl(0, 75%, 100%); opacity:1.0; line-height:1.75"><font color="black"> movies </font></mark><mark style="background-color: hsl(0, 75%, 97%); opacity:1.0; line-height:1.75"><font color="black"> you </font></mark><mark style="background-color: hsl(0, 75%, 98%); opacity:1.0; line-height:1.75"><font color="black"> can </font></mark><mark style="background-color: hsl(0, 75%, 96%); opacity:1.0; line-height:1.75"><font color="black"> ' </font></mark><mark style="background-color: hsl(0, 75%, 95%); opacity:1.0; line-height:1.75"><font color="black"> t 
</font></mark><mark style="background-color: hsl(120, 75%, 98%); opacity:1.0; line-height:1.75"><font color="black"> forget </font></mark><mark style="background-color: hsl(0, 75%, 71%); opacity:1.0; line-height:1.75"><font color="black"> . </font></mark><mark style="background-color: hsl(0, 75%, 100%); opacity:1.0; line-height:1.75"><font color="black"> [SEP] </font></mark></td><tr></table>
The visualization is clearly meaningless! :(
