Last active
May 25, 2023 21:23
-
-
Save avidale/85f6b3d294f9ff400ef76f2cc7ec559e to your computer and use it in GitHub Desktop.
Translation_Attention.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"name": "Translation_Attention.ipynb", | |
"provenance": [], | |
"collapsed_sections": [], | |
"authorship_tag": "ABX9TyMQTlNbHvmCPppjkpZ8N02x", | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
} | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/avidale/85f6b3d294f9ff400ef76f2cc7ec559e/fsmtforconditionalgeneration.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "Mas_r0htX0RQ" | |
}, | |
"source": [ | |
"Задача блокнота - показать, какое слово переводится с помощью какого при машинном переводе. " | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "WyeyufN2OL-n" | |
}, | |
"source": [ | |
"%%capture\r\n", | |
"!pip install transformers;" | |
], | |
"execution_count": 1, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "U2pUw2_CP489" | |
}, | |
"source": [ | |
"from transformers import FSMTForConditionalGeneration, FSMTTokenizer\r\n", | |
"mname = \"facebook/wmt19-en-ru\"\r\n", | |
"model = FSMTForConditionalGeneration.from_pretrained(mname)\r\n", | |
"tokenizer = FSMTTokenizer.from_pretrained(mname)" | |
], | |
"execution_count": 4, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "I9ScqHf_W4L4" | |
}, | |
"source": [ | |
"import pandas as pd\r\n", | |
"import torch" | |
], | |
"execution_count": 5, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "pJJiSf0xW8pU" | |
}, | |
"source": [ | |
"Переводим английскую сразу на русский" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "h0INt_iqPMmP", | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"outputId": "7bbffa7d-4023-4f42-ea1b-dcb20b18f20c" | |
}, | |
"source": [ | |
"inputs = \"Machine learning is great, isn't it?\"\r\n", | |
"input_ids = tokenizer.encode(inputs, return_tensors=\"pt\")\r\n", | |
"with torch.no_grad():\r\n", | |
" outputs = model.generate(input_ids)\r\n", | |
"decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)\r\n", | |
"print(decoded) # Машинное обучение - это здорово, не так ли?" | |
], | |
"execution_count": 6, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Машинное обучение - это здорово, не так ли?\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "2sftssZJW-n6" | |
}, | |
"source": [ | |
"Ещё раз прогоняем модель, усредняя аттеншны со всех слоёв. " | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "xIvpirVBWzx0", | |
"outputId": "a342b0e9-630c-44b5-f107-358c5a1f7d08" | |
}, | |
"source": [ | |
"with torch.no_grad():\r\n", | |
" out = model.base_model.forward(\r\n", | |
" input_ids=input_ids, decoder_input_ids=outputs, output_attentions=True, return_dict=True, use_cache=False,\r\n", | |
" )\r\n", | |
"cross = torch.stack(out.cross_attentions).mean(0)\r\n", | |
"print(cross.shape) # (batch_size, num_heads, sequence_length, sequence_length)" | |
], | |
"execution_count": 7, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"torch.Size([1, 16, 14, 11])\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "gaqpLU3IXJ2I" | |
}, | |
"source": [ | |
"Усредняем атеншн по всем головам, обнуляем аттеншн к EOS, и перенормируем. " | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "LNmDBV-kXCDa", | |
"outputId": "defe422b-f44e-4cc7-f6f6-62b20aedca67" | |
}, | |
"source": [ | |
"# average over all heads\r\n", | |
"mean_attention = cross.mean(dim=1).squeeze(0).numpy()\r\n", | |
"print(mean_attention.shape)\r\n", | |
"# remove attention to the end of sentence, which is dominating\r\n", | |
"mean_attention[:, -1] = 0\r\n", | |
"mean_attention = (mean_attention.T / mean_attention.sum(axis=1)).T" | |
], | |
"execution_count": 8, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"(14, 11)\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "2qlO5h-zXQ4X" | |
}, | |
"source": [ | |
"Специальный костыль для восстановления английских токенов" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "xFhQOuGVXO2s" | |
}, | |
"source": [ | |
"en_vocab = {v: k for k, v in tokenizer.encoder.items()}\r\n", | |
"def detokenize(ids):\r\n", | |
" return [en_vocab.get(idx) for idx in ids]" | |
], | |
"execution_count": 9, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "TX7zskroXWcM" | |
}, | |
"source": [ | |
"Выводим атеншн в табличку" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 462 | |
}, | |
"id": "23a_tFAhXUkh", | |
"outputId": "f6bb83e1-dc1d-443f-c4b0-396b2eea3a7e" | |
}, | |
"source": [ | |
"df = pd.DataFrame(\r\n", | |
" mean_attention, \r\n", | |
" index=tokenizer.convert_ids_to_tokens(outputs[0]),\r\n", | |
" columns=detokenize(input_ids[0].numpy())\r\n", | |
")\r\n", | |
"df.round(2)" | |
], | |
"execution_count": 10, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>Mach</th>\n", | |
" <th>ine</w></th>\n", | |
" <th>learning</w></th>\n", | |
" <th>is</w></th>\n", | |
" <th>great</w></th>\n", | |
" <th>,</w></th>\n", | |
" <th>isn</w></th>\n", | |
" <th>&apos;t</w></th>\n", | |
" <th>it</w></th>\n", | |
" <th>?</w></th>\n", | |
" <th></s></th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th></s></th>\n", | |
" <td>0.22</td>\n", | |
" <td>0.10</td>\n", | |
" <td>0.16</td>\n", | |
" <td>0.09</td>\n", | |
" <td>0.10</td>\n", | |
" <td>0.09</td>\n", | |
" <td>0.04</td>\n", | |
" <td>0.04</td>\n", | |
" <td>0.01</td>\n", | |
" <td>0.15</td>\n", | |
" <td>0.0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>Ма</th>\n", | |
" <td>0.54</td>\n", | |
" <td>0.18</td>\n", | |
" <td>0.16</td>\n", | |
" <td>0.03</td>\n", | |
" <td>0.02</td>\n", | |
" <td>0.01</td>\n", | |
" <td>0.01</td>\n", | |
" <td>0.00</td>\n", | |
" <td>0.00</td>\n", | |
" <td>0.03</td>\n", | |
" <td>0.0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>шин</th>\n", | |
" <td>0.33</td>\n", | |
" <td>0.19</td>\n", | |
" <td>0.34</td>\n", | |
" <td>0.04</td>\n", | |
" <td>0.04</td>\n", | |
" <td>0.01</td>\n", | |
" <td>0.01</td>\n", | |
" <td>0.01</td>\n", | |
" <td>0.01</td>\n", | |
" <td>0.03</td>\n", | |
" <td>0.0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>ное</w></th>\n", | |
" <td>0.24</td>\n", | |
" <td>0.14</td>\n", | |
" <td>0.44</td>\n", | |
" <td>0.05</td>\n", | |
" <td>0.04</td>\n", | |
" <td>0.02</td>\n", | |
" <td>0.01</td>\n", | |
" <td>0.01</td>\n", | |
" <td>0.01</td>\n", | |
" <td>0.04</td>\n", | |
" <td>0.0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>обучение</w></th>\n", | |
" <td>0.04</td>\n", | |
" <td>0.03</td>\n", | |
" <td>0.25</td>\n", | |
" <td>0.24</td>\n", | |
" <td>0.28</td>\n", | |
" <td>0.04</td>\n", | |
" <td>0.03</td>\n", | |
" <td>0.02</td>\n", | |
" <td>0.02</td>\n", | |
" <td>0.05</td>\n", | |
" <td>0.0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>-</w></th>\n", | |
" <td>0.02</td>\n", | |
" <td>0.01</td>\n", | |
" <td>0.06</td>\n", | |
" <td>0.35</td>\n", | |
" <td>0.33</td>\n", | |
" <td>0.06</td>\n", | |
" <td>0.03</td>\n", | |
" <td>0.03</td>\n", | |
" <td>0.02</td>\n", | |
" <td>0.09</td>\n", | |
" <td>0.0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>это</w></th>\n", | |
" <td>0.01</td>\n", | |
" <td>0.01</td>\n", | |
" <td>0.04</td>\n", | |
" <td>0.29</td>\n", | |
" <td>0.40</td>\n", | |
" <td>0.06</td>\n", | |
" <td>0.04</td>\n", | |
" <td>0.04</td>\n", | |
" <td>0.04</td>\n", | |
" <td>0.07</td>\n", | |
" <td>0.0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>здорово</w></th>\n", | |
" <td>0.02</td>\n", | |
" <td>0.01</td>\n", | |
" <td>0.03</td>\n", | |
" <td>0.12</td>\n", | |
" <td>0.31</td>\n", | |
" <td>0.17</td>\n", | |
" <td>0.09</td>\n", | |
" <td>0.06</td>\n", | |
" <td>0.04</td>\n", | |
" <td>0.14</td>\n", | |
" <td>0.0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>,</w></th>\n", | |
" <td>0.01</td>\n", | |
" <td>0.01</td>\n", | |
" <td>0.02</td>\n", | |
" <td>0.03</td>\n", | |
" <td>0.06</td>\n", | |
" <td>0.25</td>\n", | |
" <td>0.28</td>\n", | |
" <td>0.19</td>\n", | |
" <td>0.04</td>\n", | |
" <td>0.12</td>\n", | |
" <td>0.0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>не</w></th>\n", | |
" <td>0.01</td>\n", | |
" <td>0.01</td>\n", | |
" <td>0.01</td>\n", | |
" <td>0.03</td>\n", | |
" <td>0.05</td>\n", | |
" <td>0.04</td>\n", | |
" <td>0.35</td>\n", | |
" <td>0.28</td>\n", | |
" <td>0.08</td>\n", | |
" <td>0.12</td>\n", | |
" <td>0.0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>так</w></th>\n", | |
" <td>0.02</td>\n", | |
" <td>0.01</td>\n", | |
" <td>0.02</td>\n", | |
" <td>0.02</td>\n", | |
" <td>0.04</td>\n", | |
" <td>0.05</td>\n", | |
" <td>0.27</td>\n", | |
" <td>0.22</td>\n", | |
" <td>0.14</td>\n", | |
" <td>0.21</td>\n", | |
" <td>0.0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>ли</w></th>\n", | |
" <td>0.03</td>\n", | |
" <td>0.01</td>\n", | |
" <td>0.04</td>\n", | |
" <td>0.04</td>\n", | |
" <td>0.05</td>\n", | |
" <td>0.04</td>\n", | |
" <td>0.21</td>\n", | |
" <td>0.11</td>\n", | |
" <td>0.14</td>\n", | |
" <td>0.33</td>\n", | |
" <td>0.0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>?</w></th>\n", | |
" <td>0.04</td>\n", | |
" <td>0.03</td>\n", | |
" <td>0.04</td>\n", | |
" <td>0.03</td>\n", | |
" <td>0.04</td>\n", | |
" <td>0.05</td>\n", | |
" <td>0.06</td>\n", | |
" <td>0.05</td>\n", | |
" <td>0.05</td>\n", | |
" <td>0.62</td>\n", | |
" <td>0.0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th></s></th>\n", | |
" <td>0.11</td>\n", | |
" <td>0.06</td>\n", | |
" <td>0.10</td>\n", | |
" <td>0.04</td>\n", | |
" <td>0.07</td>\n", | |
" <td>0.04</td>\n", | |
" <td>0.06</td>\n", | |
" <td>0.05</td>\n", | |
" <td>0.08</td>\n", | |
" <td>0.39</td>\n", | |
" <td>0.0</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" Mach ine</w> learning</w> ... it</w> ?</w> </s>\n", | |
"</s> 0.22 0.10 0.16 ... 0.01 0.15 0.0\n", | |
"Ма 0.54 0.18 0.16 ... 0.00 0.03 0.0\n", | |
"шин 0.33 0.19 0.34 ... 0.01 0.03 0.0\n", | |
"ное</w> 0.24 0.14 0.44 ... 0.01 0.04 0.0\n", | |
"обучение</w> 0.04 0.03 0.25 ... 0.02 0.05 0.0\n", | |
"-</w> 0.02 0.01 0.06 ... 0.02 0.09 0.0\n", | |
"это</w> 0.01 0.01 0.04 ... 0.04 0.07 0.0\n", | |
"здорово</w> 0.02 0.01 0.03 ... 0.04 0.14 0.0\n", | |
",</w> 0.01 0.01 0.02 ... 0.04 0.12 0.0\n", | |
"не</w> 0.01 0.01 0.01 ... 0.08 0.12 0.0\n", | |
"так</w> 0.02 0.01 0.02 ... 0.14 0.21 0.0\n", | |
"ли</w> 0.03 0.01 0.04 ... 0.14 0.33 0.0\n", | |
"?</w> 0.04 0.03 0.04 ... 0.05 0.62 0.0\n", | |
"</s> 0.11 0.06 0.10 ... 0.08 0.39 0.0\n", | |
"\n", | |
"[14 rows x 11 columns]" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 10 | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "W2d6FGhZXYmG" | |
}, | |
"source": [ | |
"Бонус: варианты перевода" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "TbUfsKClbMAA", | |
"outputId": "37adc23f-4530-4636-a428-e1ead366fcbf" | |
}, | |
"source": [ | |
"with torch.no_grad():\r\n", | |
" outputs = model.generate(input_ids, num_beams=5, num_return_sequences=5)\r\n", | |
"for o in outputs:\r\n", | |
" print(tokenizer.decode(o, skip_special_tokens=True))" | |
], | |
"execution_count": 12, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Машинное обучение - это здорово, не так ли?\n", | |
"Машинное обучение - это прекрасно, не так ли?\n", | |
"Машинное обучение - это здорово, не правда ли?\n", | |
"Машинное обучение - это прекрасно, не правда ли?\n", | |
"Машинное обучение - это замечательно, не так ли?\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "A76IsbTqbWHP", | |
"outputId": "42b5a6fd-b041-4488-808d-760789fe8789" | |
}, | |
"source": [ | |
"with torch.no_grad():\r\n", | |
" outputs = model.generate(input_ids, do_sample=True, num_return_sequences=5, temperature=2.0, max_length=20)\r\n", | |
"for o in outputs:\r\n", | |
" print(tokenizer.decode(o, skip_special_tokens=True))" | |
], | |
"execution_count": 19, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Машинное обучение - замечательная штука, не так ли?\n", | |
"Машинное обучение это здорово, не так ли?\n", | |
"Машинное усвоение является отличным способом, не так ли?\n", | |
"Машинное познание - это прекрасно, не так ли?\n", | |
"Машинное обучение - большое дело, не так ли?\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "4E-Fblr3bd7S" | |
}, | |
"source": [ | |
"" | |
], | |
"execution_count": null, | |
"outputs": [] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment