Skip to content

Instantly share code, notes, and snippets.

@davidefiocco
Last active March 31, 2020 22:18
Show Gist options
  • Save davidefiocco/40a1395e895174a4e4d3ed424a5d388a to your computer and use it in GitHub Desktop.
Save davidefiocco/40a1395e895174a4e4d3ed424a5d388a to your computer and use it in GitHub Desktop.
Trying captum interpretation on pretrained sentiment classifier
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"orig_nbformat": 2,
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"npconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": 3,
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"colab": {
"name": "Interpretation.ipynb",
"provenance": [],
"collapsed_sections": []
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "UFESEuEgbUDD",
"colab_type": "text"
},
"source": [
"# Interpretation of BertForSequenceClassification in captum\n",
"\n",
"In this notebook we use Captum to interpret a BERT sentiment classifier finetuned on the imdb dataset https://huggingface.co/lvwerra/bert-imdb "
]
},
{
"cell_type": "code",
"metadata": {
"id": "EJ51JAxHbghp",
"colab_type": "code",
"colab": {}
},
"source": [
"# Install dependencies\n",
"!pip install transformers\n",
"!pip install captum"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "CS9Kaz8ubUDG",
"colab_type": "code",
"colab": {}
},
"source": [
"from transformers import BertTokenizer, BertForSequenceClassification, BertConfig\n",
"from captum.attr import visualization as viz\n",
"from captum.attr import IntegratedGradients, LayerConductance, LayerIntegratedGradients\n",
"from captum.attr import configure_interpretable_embedding_layer, remove_interpretable_embedding_layer\n",
"import torch\n",
"import matplotlib.pyplot as plt"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "P1yl1gdvbUDS",
"colab_type": "code",
"colab": {}
},
"source": [
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "3U5XDt1Gb73t",
"colab_type": "code",
"colab": {}
},
"source": [
"# Get model and config files from https://huggingface.co/lvwerra/bert-imdb\n",
"!wget -P ./model https://s3.amazonaws.com/models.huggingface.co/bert/lvwerra/bert-imdb/config.json\n",
"!wget -P ./model https://s3.amazonaws.com/models.huggingface.co/bert/lvwerra/bert-imdb/pytorch_model.bin\n",
"!wget -P ./model https://s3.amazonaws.com/models.huggingface.co/bert/lvwerra/bert-imdb/special_tokens_map.json\n",
"!wget -P ./model https://s3.amazonaws.com/models.huggingface.co/bert/lvwerra/bert-imdb/tokenizer_config.json\n",
"!wget -P ./model https://s3.amazonaws.com/models.huggingface.co/bert/lvwerra/bert-imdb/training_args.bin\n",
"!wget -P ./model https://s3.amazonaws.com/models.huggingface.co/bert/lvwerra/bert-imdb/vocab.txt"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "X-nyyq_tbUDa",
"colab_type": "code",
"colab": {}
},
"source": [
"# load model\n",
"model = BertForSequenceClassification.from_pretrained('./model')\n",
"model.to(device)\n",
"model.eval()\n",
"model.zero_grad()\n",
"\n",
"# load tokenizer\n",
"tokenizer = BertTokenizer.from_pretrained('./model')"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "JUMsvUOTbUDi",
"colab_type": "code",
"colab": {}
},
"source": [
"def predict(inputs):\n",
" return model(inputs)[0]"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "SIbauwGbbUDo",
"colab_type": "code",
"colab": {}
},
"source": [
"ref_token_id = tokenizer.pad_token_id # A token used for generating token reference\n",
"sep_token_id = tokenizer.sep_token_id # A token used as a separator between question and text and it is also added to the end of the text.\n",
"cls_token_id = tokenizer.cls_token_id # A token used for prepending to the concatenated question-text word sequence"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "mcnTCNUFbUD1",
"colab_type": "code",
"colab": {}
},
"source": [
"def construct_input_ref_pair(text, ref_token_id, sep_token_id, cls_token_id):\n",
"\n",
" text_ids = tokenizer.encode(text, add_special_tokens=False)\n",
" # construct input token ids\n",
" input_ids = [cls_token_id] + text_ids + [sep_token_id]\n",
" # construct reference token ids \n",
" ref_input_ids = [cls_token_id] + [ref_token_id] * len(text_ids) + [sep_token_id]\n",
"\n",
" return torch.tensor([input_ids], device=device), torch.tensor([ref_input_ids], device=device), len(text_ids)\n",
"\n",
"def construct_input_ref_token_type_pair(input_ids, sep_ind=0):\n",
" seq_len = input_ids.size(1)\n",
" token_type_ids = torch.tensor([[0 if i <= sep_ind else 1 for i in range(seq_len)]], device=device)\n",
" ref_token_type_ids = torch.zeros_like(token_type_ids, device=device)# * -1\n",
" return token_type_ids, ref_token_type_ids\n",
"\n",
"def construct_input_ref_pos_id_pair(input_ids):\n",
" seq_length = input_ids.size(1)\n",
" position_ids = torch.arange(seq_length, dtype=torch.long, device=device)\n",
" # we could potentially also use random permutation with `torch.randperm(seq_length, device=device)`\n",
" ref_position_ids = torch.zeros(seq_length, dtype=torch.long, device=device)\n",
"\n",
" position_ids = position_ids.unsqueeze(0).expand_as(input_ids)\n",
" ref_position_ids = ref_position_ids.unsqueeze(0).expand_as(input_ids)\n",
" return position_ids, ref_position_ids\n",
" \n",
"def construct_attention_mask(input_ids):\n",
" return torch.ones_like(input_ids)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "vhasPia4bUD8",
"colab_type": "code",
"colab": {}
},
"source": [
"def custom_forward(inputs):\n",
" preds = predict(inputs)\n",
" return torch.softmax(preds, dim = 1)[0][0].unsqueeze(-1)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "pGwkb1vAbUEA",
"colab_type": "code",
"colab": {}
},
"source": [
"lig = LayerIntegratedGradients(custom_forward, model.bert.embeddings)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "EQlVDaISbUEF",
"colab_type": "code",
"colab": {}
},
"source": [
"# One can test a couple of examples and check that the sentiment classifier is behaving\n",
"text = \"The movie was one of those amazing movies you can't forget.\"\n",
"#text = \"The movie was one of those crappy movies you can't forget.\""
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "BtoFctjVbUEM",
"colab_type": "code",
"colab": {}
},
"source": [
"input_ids, ref_input_ids, sep_id = construct_input_ref_pair(text, ref_token_id, sep_token_id, cls_token_id)\n",
"token_type_ids, ref_token_type_ids = construct_input_ref_token_type_pair(input_ids, sep_id)\n",
"position_ids, ref_position_ids = construct_input_ref_pos_id_pair(input_ids)\n",
"attention_mask = construct_attention_mask(input_ids)\n",
"\n",
"indices = input_ids[0].detach().tolist()\n",
"all_tokens = tokenizer.convert_ids_to_tokens(indices)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "T4vlqBBrbUEY",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 35
},
"outputId": "1a54bbcf-955c-4043-ebf6-c4ff9997bbe4"
},
"source": [
"# Check predict output\n",
"predict(input_ids)"
],
"execution_count": 12,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"tensor([[-3.3635, 4.0115]], device='cuda:0', grad_fn=<AddmmBackward>)"
]
},
"metadata": {
"tags": []
},
"execution_count": 12
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "wpNkwy6_bUEd",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 35
},
"outputId": "e4f6d662-5753-4baa-cb49-f9a3c35f21d9"
},
"source": [
"# Check output of custom_forward\n",
"custom_forward(input_ids)"
],
"execution_count": 13,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"tensor([0.0006], device='cuda:0', grad_fn=<UnsqueezeBackward0>)"
]
},
"metadata": {
"tags": []
},
"execution_count": 13
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "YAzBqQlpbUEk",
"colab_type": "code",
"colab": {}
},
"source": [
"attributions, delta = lig.attribute(inputs=input_ids,\n",
" baselines=ref_input_ids,\n",
" return_convergence_delta=True)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "dU8SRQFybUEo",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 54
},
"outputId": "bcadcde0-8ce7-4b0e-814c-6b14570f1e2a"
},
"source": [
"score = predict(input_ids)\n",
"\n",
"print('Sentence: ', text)\n",
"print('Sentiment: ' + str(torch.argmax(score[0]).cpu().numpy()) + \\\n",
" ', Probability positive: ' + str(torch.softmax(score, dim = 1)[0][1].cpu().detach().numpy()))"
],
"execution_count": 15,
"outputs": [
{
"output_type": "stream",
"text": [
"Sentence: The movie was one of those amazing movies you can't forget.\n",
"Sentiment: 1, Probability positive: 0.99937373\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "Hq8R_ZYubUEu",
"colab_type": "code",
"colab": {}
},
"source": [
"def summarize_attributions(attributions):\n",
" attributions = attributions.sum(dim=-1).squeeze(0)\n",
" attributions = attributions / torch.norm(attributions)\n",
" return attributions"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "3q7xXwRrbUEx",
"colab_type": "code",
"colab": {}
},
"source": [
"attributions_sum = summarize_attributions(attributions)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "0ZF0RmZ4bUE1",
"colab_type": "code",
"colab": {}
},
"source": [
"# storing couple samples in an array for visualization purposes\n",
"score_vis = viz.VisualizationDataRecord(attributions_sum,\n",
" torch.softmax(score, dim = 1)[0][0],\n",
" torch.argmax(torch.softmax(score, dim = 1)[0]),\n",
" 0,\n",
" text,\n",
" attributions_sum.sum(), \n",
" all_tokens,\n",
" delta)\n"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "-gAojuO6ody0",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 131
},
"outputId": "6875e039-c494-4b0f-ec78-ff192daf0918"
},
"source": [
"print('\\033[1m', 'Visualization For Score', '\\033[0m')\n",
"viz.visualize_text([score_vis])"
],
"execution_count": 19,
"outputs": [
{
"output_type": "stream",
"text": [
"\u001b[1m Visualization For Score \u001b[0m\n"
],
"name": "stdout"
},
{
"output_type": "display_data",
"data": {
"text/html": [
"<table width: 100%><tr><th>True Label</th><th>Predicted Label</th><th>Attribution Label</th><th>Attribution Score</th><th>Word Importance</th><tr><td><text style=\"padding-right:2em\"><b>0</b></text></td><td><text style=\"padding-right:2em\"><b>1 (0.00)</b></text></td><td><text style=\"padding-right:2em\"><b>The movie was one of those amazing movies you can't forget.</b></text></td><td><text style=\"padding-right:2em\"><b>-0.72</b></text></td><td><mark style=\"background-color: hsl(0, 75%, 100%); opacity:1.0; line-height:1.75\"><font color=\"black\"> [CLS] </font></mark><mark style=\"background-color: hsl(120, 75%, 92%); opacity:1.0; line-height:1.75\"><font color=\"black\"> The </font></mark><mark style=\"background-color: hsl(0, 75%, 100%); opacity:1.0; line-height:1.75\"><font color=\"black\"> movie </font></mark><mark style=\"background-color: hsl(0, 75%, 96%); opacity:1.0; line-height:1.75\"><font color=\"black\"> was </font></mark><mark style=\"background-color: hsl(120, 75%, 71%); opacity:1.0; line-height:1.75\"><font color=\"black\"> one </font></mark><mark style=\"background-color: hsl(0, 75%, 95%); opacity:1.0; line-height:1.75\"><font color=\"black\"> of </font></mark><mark style=\"background-color: hsl(0, 75%, 97%); opacity:1.0; line-height:1.75\"><font color=\"black\"> those </font></mark><mark style=\"background-color: hsl(0, 75%, 99%); opacity:1.0; line-height:1.75\"><font color=\"black\"> amazing </font></mark><mark style=\"background-color: hsl(0, 75%, 100%); opacity:1.0; line-height:1.75\"><font color=\"black\"> movies </font></mark><mark style=\"background-color: hsl(0, 75%, 97%); opacity:1.0; line-height:1.75\"><font color=\"black\"> you </font></mark><mark style=\"background-color: hsl(0, 75%, 98%); opacity:1.0; line-height:1.75\"><font color=\"black\"> can </font></mark><mark style=\"background-color: hsl(0, 75%, 96%); opacity:1.0; line-height:1.75\"><font color=\"black\"> ' </font></mark><mark style=\"background-color: hsl(0, 75%, 95%); opacity:1.0; line-height:1.75\"><font color=\"black\"> t </font></mark><mark style=\"background-color: hsl(120, 75%, 98%); opacity:1.0; line-height:1.75\"><font color=\"black\"> forget </font></mark><mark style=\"background-color: hsl(0, 75%, 71%); opacity:1.0; line-height:1.75\"><font color=\"black\"> . </font></mark><mark style=\"background-color: hsl(0, 75%, 100%); opacity:1.0; line-height:1.75\"><font color=\"black\"> [SEP] </font></mark></td><tr></table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {
"tags": []
}
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ItXD4N9FogZu",
"colab_type": "text"
},
"source": [
"The visualization is clearly meaningless! :(\n"
]
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment