{
"nbformat": 4,
"nbformat_minor": 2,
"metadata": {
"language_info": {
"name": "python",
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"version": "3.7.4-final"
},
"orig_nbformat": 2,
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"npconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": 3,
"kernelspec": {
"name": "python37464bitbaseconda591eac30377d4dc3af76304e9e0933b9",
"display_name": "Python 3.7.4 64-bit ('base': conda)"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Interpretation of BertForSequenceClassification in captum"
]
},
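{
"cell_type": "markdown",
"metadata": {},
"source": [
"This notebook adapts Captum's question-answering interpretation tutorial to binary sequence classification: a `BertForSequenceClassification` model fine-tuned on CoLA is probed with Layer Integrated Gradients, attributing the predicted class probability to the individual input tokens."
]
},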
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": "I0304 15:15:00.931115 12960 file_utils.py:41] PyTorch version 1.4.0 available.\n"
}
],
"source": [
"from transformers import BertTokenizer, BertForSequenceClassification, BertConfig\n",
"\n",
"from captum.attr import visualization as viz\n",
"from captum.attr import IntegratedGradients, LayerConductance, LayerIntegratedGradients\n",
"from captum.attr import configure_interpretable_embedding_layer, remove_interpretable_embedding_layer\n",
"\n",
"import torch"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": "I0304 15:15:05.666696 12960 configuration_utils.py:254] loading configuration file ../model/config.json\nI0304 15:15:05.669722 12960 configuration_utils.py:292] Model config BertConfig {\n \"architectures\": [\n \"BertForSequenceClassification\"\n ],\n \"attention_probs_dropout_prob\": 0.1,\n \"bos_token_id\": null,\n \"do_sample\": false,\n \"eos_token_ids\": null,\n \"finetuning_task\": \"cola\",\n \"hidden_act\": \"gelu\",\n \"hidden_dropout_prob\": 0.1,\n \"hidden_size\": 768,\n \"id2label\": {\n \"0\": \"LABEL_0\",\n \"1\": \"LABEL_1\"\n },\n \"initializer_range\": 0.02,\n \"intermediate_size\": 3072,\n \"is_decoder\": false,\n \"label2id\": {\n \"LABEL_0\": 0,\n \"LABEL_1\": 1\n },\n \"layer_norm_eps\": 1e-12,\n \"length_penalty\": 1.0,\n \"max_length\": 20,\n \"max_position_embeddings\": 512,\n \"model_type\": \"bert\",\n \"num_attention_heads\": 12,\n \"num_beams\": 1,\n \"num_hidden_layers\": 12,\n \"num_labels\": 2,\n \"num_return_sequences\": 1,\n \"output_attentions\": false,\n \"output_hidden_states\": false,\n \"output_past\": true,\n \"pad_token_id\": null,\n \"pruned_heads\": {},\n \"repetition_penalty\": 1.0,\n \"temperature\": 1.0,\n \"top_k\": 50,\n \"top_p\": 1.0,\n \"torchscript\": false,\n \"type_vocab_size\": 2,\n \"use_bfloat16\": false,\n \"vocab_size\": 31116\n}\n\nI0304 15:15:05.673698 12960 modeling_utils.py:459] loading weights file ../model/pytorch_model.bin\nI0304 15:15:09.035054 12960 tokenization_utils.py:417] Model name '../model/' not found in model shortcut name list (bert-base-uncased, bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, bert-base-multilingual-cased, bert-base-chinese, bert-base-german-cased, bert-large-uncased-whole-word-masking, bert-large-cased-whole-word-masking, bert-large-uncased-whole-word-masking-finetuned-squad, bert-large-cased-whole-word-masking-finetuned-squad, bert-base-cased-finetuned-mrpc, bert-base-german-dbmdz-cased, bert-base-german-dbmdz-uncased, bert-base-finnish-cased-v1, bert-base-finnish-uncased-v1, bert-base-dutch-cased). Assuming '../model/' is a path, a model identifier, or url to a directory containing tokenizer files.\nI0304 15:15:09.048053 12960 tokenization_utils.py:446] Didn't find file ../model/added_tokens.json. We won't load it.\nI0304 15:15:09.052233 12960 tokenization_utils.py:499] loading file ../model/vocab.txt\nI0304 15:15:09.056055 12960 tokenization_utils.py:499] loading file None\nI0304 15:15:09.058056 12960 tokenization_utils.py:499] loading file ../model/special_tokens_map.json\nI0304 15:15:09.060055 12960 tokenization_utils.py:499] loading file ../model/tokenizer_config.json\n"
}
],
"source": [
"\n",
"# load model\n",
"model = BertForSequenceClassification.from_pretrained('../model/')\n",
"model.to(device)\n",
"model.eval()\n",
"model.zero_grad()\n",
"\n",
"# load tokenizer\n",
"tokenizer = BertTokenizer.from_pretrained('../model/')"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"def predict(inputs):\n",
" return model(inputs)[0]"
]
},
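{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, the logits returned by `predict` can be turned into class probabilities with a softmax; `predict_proba` below is a hypothetical helper added for illustration and is not used in the rest of the notebook."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical helper (illustration only): logits -> class probabilities\n",
"def predict_proba(inputs):\n",
" return torch.softmax(predict(inputs), dim=1)"
]
},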
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"ref_token_id = tokenizer.pad_token_id # A token used for generating token reference\n",
"sep_token_id = tokenizer.sep_token_id # A token used as a separator between question and text and it is also added to the end of the text.\n",
"cls_token_id = tokenizer.cls_token_id # A token used for prepending to the concatenated question-text word sequence"
]
},
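{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check we can print the three special-token ids; the exact values depend on the loaded vocabulary (for `bert-base-uncased` they would be 0, 102 and 101)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Inspect the special-token ids; values depend on the model's vocab\n",
"print(ref_token_id, sep_token_id, cls_token_id)"
]
},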
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"def construct_input_ref_pair(text, ref_token_id, sep_token_id, cls_token_id):\n",
"\n",
" text_ids = tokenizer.encode(text, add_special_tokens=False)\n",
" # construct input token ids\n",
" input_ids = [cls_token_id] + text_ids + [sep_token_id]\n",
" # construct reference token ids \n",
" ref_input_ids = [cls_token_id] + [ref_token_id] * len(text_ids) + [sep_token_id]\n",
"\n",
" return torch.tensor([input_ids], device=device), torch.tensor([ref_input_ids], device=device), len(text_ids)\n",
"\n",
"def construct_input_ref_token_type_pair(input_ids, sep_ind=0):\n",
" seq_len = input_ids.size(1)\n",
" token_type_ids = torch.tensor([[0 if i <= sep_ind else 1 for i in range(seq_len)]], device=device)\n",
" ref_token_type_ids = torch.zeros_like(token_type_ids, device=device)# * -1\n",
" return token_type_ids, ref_token_type_ids\n",
"\n",
"def construct_input_ref_pos_id_pair(input_ids):\n",
" seq_length = input_ids.size(1)\n",
" position_ids = torch.arange(seq_length, dtype=torch.long, device=device)\n",
" # we could potentially also use random permutation with `torch.randperm(seq_length, device=device)`\n",
" ref_position_ids = torch.zeros(seq_length, dtype=torch.long, device=device)\n",
"\n",
" position_ids = position_ids.unsqueeze(0).expand_as(input_ids)\n",
" ref_position_ids = ref_position_ids.unsqueeze(0).expand_as(input_ids)\n",
" return position_ids, ref_position_ids\n",
" \n",
"def construct_attention_mask(input_ids):\n",
" return torch.ones_like(input_ids)"
]
},
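{
"cell_type": "markdown",
"metadata": {},
"source": [
"A short sketch of what the input/baseline pair looks like: the baseline keeps `[CLS]` and `[SEP]` in place and replaces every text token with the padding token."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustration: build and inspect an input/baseline pair for a toy sentence\n",
"ids, ref_ids, n = construct_input_ref_pair(\"Hello world\", ref_token_id, sep_token_id, cls_token_id)\n",
"print(tokenizer.convert_ids_to_tokens(ids[0].tolist()))\n",
"print(tokenizer.convert_ids_to_tokens(ref_ids[0].tolist()))"
]
},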
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"def custom_forward(inputs):\n",
" preds = predict(inputs)\n",
" return torch.softmax(preds, dim = 1)[0][0].unsqueeze(-1)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"lig = LayerIntegratedGradients(custom_forward, model.bert.embeddings)"
]
},
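{
"cell_type": "markdown",
"metadata": {},
"source": [
"`LayerIntegratedGradients` is set up here to attribute with respect to the whole embedding block (word, position and token-type embeddings summed). A sub-module can be targeted instead; the variant below, attributing only to the word-piece embeddings, is a sketch and is not used further in this notebook."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Alternative (illustration only): attribute w.r.t. the word embeddings alone\n",
"lig_word = LayerIntegratedGradients(custom_forward, model.bert.embeddings.word_embeddings)"
]
},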
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"text = \"These tests do not work as expected.\""
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"input_ids, ref_input_ids, sep_id = construct_input_ref_pair(text, ref_token_id, sep_token_id, cls_token_id)\n",
"token_type_ids, ref_token_type_ids = construct_input_ref_token_type_pair(input_ids, sep_id)\n",
"position_ids, ref_position_ids = construct_input_ref_pos_id_pair(input_ids)\n",
"attention_mask = construct_attention_mask(input_ids)\n",
"\n",
"indices = input_ids[0].detach().tolist()\n",
"all_tokens = tokenizer.convert_ids_to_tokens(indices)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": "(tensor([[-3.4676, 3.5508]], grad_fn=<AddmmBackward>),)"
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model(input_ids)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": "tensor([[-3.4676, 3.5508]], grad_fn=<AddmmBackward>)"
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"predict(input_ids)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": "tensor([0.0009], grad_fn=<UnsqueezeBackward0>)"
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"custom_forward(input_ids)"
]
},
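{
"cell_type": "markdown",
"metadata": {},
"source": [
"Since `LayerIntegratedGradients` evaluates all interpolation steps as a single batch, `custom_forward` must return one value per input row. A quick shape check on two stacked copies of the sentence (a sketch added for illustration):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# The batched forward should return one probability per row\n",
"batch = torch.cat([input_ids, input_ids], dim=0)\n",
"print(custom_forward(batch).shape) # expected: torch.Size([2])"
]
},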
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"attributions, delta = lig.attribute(inputs=input_ids,\n",
" baselines=ref_input_ids,\n",
" return_convergence_delta=True)"
]
},
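{
"cell_type": "markdown",
"metadata": {},
"source": [
"`delta` estimates how far the summed attributions are from `custom_forward(input) - custom_forward(baseline)`; the larger its magnitude, the coarser the approximation, and the `n_steps` argument of `lig.attribute` (Captum's default is 50) can be raised to reduce it."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Convergence delta: closer to zero means a more faithful approximation\n",
"print('convergence delta:', delta)"
]
},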
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": "Question:These tests do not work as expected.\nPredicted Answer: 1, prob ungrammatical: 0.0008944187\n"
}
],
"source": [
"score = predict(input_ids)\n",
"\n",
"print('Question: ', text)\n",
"print('Predicted Answer: ' + str(torch.argmax(score[0]).numpy()) + ', prob ungrammatical: ' + str(torch.softmax(score, dim = 1)[0][0].detach().numpy()))"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"def summarize_attributions(attributions):\n",
" attributions = attributions.sum(dim=-1).squeeze(0)\n",
" attributions = attributions / torch.norm(attributions)\n",
" return attributions"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"attributions_sum = summarize_attributions(attributions)"
]
},
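{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before the HTML visualization, the per-token scores can be dumped as plain text (a sketch; positive scores push the prediction towards class 0, the output that `custom_forward` exposes)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Plain-text view: each token alongside its normalized attribution score\n",
"for token, attr in zip(all_tokens, attributions_sum):\n",
" print(f\"{token:>12} {attr.item():+.3f}\")"
]
},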
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": "Visualization For Score\n"
},
{
"data": {
"text/html": "<table width: 100%><tr><th>True Label</th><th>Predicted Label</th><th>Attribution Label</th><th>Attribution Score</th><th>Word Importance</th><tr><td><text style=\"padding-right:2em\"><b>0</b></text></td><td><text style=\"padding-right:2em\"><b>1 (0.00)</b></text></td><td><text style=\"padding-right:2em\"><b>These tests do not work as expected.</b></text></td><td><text style=\"padding-right:2em\"><b>0.73</b></text></td><td><mark style=\"background-color: hsl(0, 75%, 100%); opacity:1.0; line-height:1.75\"><font color=\"black\"> [CLS] </font></mark><mark style=\"background-color: hsl(0, 75%, 97%); opacity:1.0; line-height:1.75\"><font color=\"black\"> These </font></mark><mark style=\"background-color: hsl(120, 75%, 68%); opacity:1.0; line-height:1.75\"><font color=\"black\"> tests </font></mark><mark style=\"background-color: hsl(120, 75%, 90%); opacity:1.0; line-height:1.75\"><font color=\"black\"> do </font></mark><mark style=\"background-color: hsl(120, 75%, 86%); opacity:1.0; line-height:1.75\"><font color=\"black\"> not </font></mark><mark style=\"background-color: hsl(120, 75%, 80%); opacity:1.0; line-height:1.75\"><font color=\"black\"> work </font></mark><mark style=\"background-color: hsl(120, 75%, 100%); opacity:1.0; line-height:1.75\"><font color=\"black\"> as </font></mark><mark style=\"background-color: hsl(0, 75%, 88%); opacity:1.0; line-height:1.75\"><font color=\"black\"> expected </font></mark><mark style=\"background-color: hsl(0, 75%, 83%); opacity:1.0; line-height:1.75\"><font color=\"black\"> . </font></mark><mark style=\"background-color: hsl(0, 75%, 100%); opacity:1.0; line-height:1.75\"><font color=\"black\"> [SEP] </font></mark></td><tr></table>",
"text/plain": "<IPython.core.display.HTML object>"
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# storing couple samples in an array for visualization purposes\n",
"score_vis = viz.VisualizationDataRecord(\n",
" attributions_sum,\n",
" torch.softmax(score, dim = 1)[0][0],\n",
" torch.argmax(torch.softmax(score, dim = 1)[0]),\n",
" 0,\n",
" text,\n",
" attributions_sum.sum(), \n",
" all_tokens,\n",
" delta)\n",
"\n",
"print('\\033[1m', 'Visualization For Score', '\\033[0m')\n",
"viz.visualize_text([score_vis])"
]
}
]
}