@avidale
Last active May 25, 2023 21:23
Translation_Attention.ipynb
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "Translation_Attention.ipynb",
"provenance": [],
"collapsed_sections": [],
"authorship_tag": "ABX9TyMQTlNbHvmCPppjkpZ8N02x",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/avidale/85f6b3d294f9ff400ef76f2cc7ec559e/fsmtforconditionalgeneration.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Mas_r0htX0RQ"
},
"source": [
"The goal of this notebook is to show, using the model's cross-attention, which source word each word of a machine translation is translated from."
]
},
{
"cell_type": "code",
"metadata": {
"id": "WyeyufN2OL-n"
},
"source": [
"%%capture\r\n",
"!pip install transformers;"
],
"execution_count": 1,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "U2pUw2_CP489"
},
"source": [
"from transformers import FSMTForConditionalGeneration, FSMTTokenizer\r\n",
"mname = \"facebook/wmt19-en-ru\"\r\n",
"model = FSMTForConditionalGeneration.from_pretrained(mname)\r\n",
"tokenizer = FSMTTokenizer.from_pretrained(mname)"
],
"execution_count": 4,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "I9ScqHf_W4L4"
},
"source": [
"import pandas as pd\r\n",
"import torch"
],
"execution_count": 5,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "pJJiSf0xW8pU"
},
"source": [
"Translate an English sentence directly into Russian."
]
},
{
"cell_type": "code",
"metadata": {
"id": "h0INt_iqPMmP",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "7bbffa7d-4023-4f42-ea1b-dcb20b18f20c"
},
"source": [
"inputs = \"Machine learning is great, isn't it?\"\r\n",
"input_ids = tokenizer.encode(inputs, return_tensors=\"pt\")\r\n",
"with torch.no_grad():\r\n",
" outputs = model.generate(input_ids)\r\n",
"decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)\r\n",
"print(decoded) # Машинное обучение - это здорово, не так ли?"
],
"execution_count": 6,
"outputs": [
{
"output_type": "stream",
"text": [
"Машинное обучение - это здорово, не так ли?\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "2sftssZJW-n6"
},
"source": [
"Run the model once more on the generated translation, averaging the cross-attention over all layers."
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "xIvpirVBWzx0",
"outputId": "a342b0e9-630c-44b5-f107-358c5a1f7d08"
},
"source": [
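"with torch.no_grad():\r\n",
"    # re-run the encoder-decoder on the generated ids (teacher forcing) to collect attention weights\r\n",
"    out = model.base_model(\r\n",
"        input_ids=input_ids, decoder_input_ids=outputs, output_attentions=True, return_dict=True, use_cache=False,\r\n",
"    )\r\n",
"# cross_attentions holds one tensor per decoder layer; average over the layers\r\n",
"cross = torch.stack(out.cross_attentions).mean(0)\r\n",
"print(cross.shape)  # (batch_size, num_heads, target_length, source_length)"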
],
"execution_count": 7,
"outputs": [
{
"output_type": "stream",
"text": [
"torch.Size([1, 16, 14, 11])\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "gaqpLU3IXJ2I"
},
"source": [
"Average the attention over all heads, zero out the attention to EOS, and renormalize."
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "LNmDBV-kXCDa",
"outputId": "defe422b-f44e-4cc7-f6f6-62b20aedca67"
},
"source": [
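"# average over all heads and drop the batch dimension: (target_length, source_length)\r\n",
"mean_attention = cross.mean(dim=1).squeeze(0).numpy()\r\n",
"print(mean_attention.shape)\r\n",
"# zero out the attention to the source end-of-sentence token </s>, which dominates\r\n",
"mean_attention[:, -1] = 0\r\n",
"# renormalize so that each row (target position) sums to 1 again\r\n",
"mean_attention = mean_attention / mean_attention.sum(axis=1, keepdims=True)"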
],
"execution_count": 8,
"outputs": [
{
"output_type": "stream",
"text": [
"(14, 11)\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "2qlO5h-zXQ4X"
},
"source": [
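"A workaround for recovering the English tokens: `tokenizer.convert_ids_to_tokens` looks ids up in the target (Russian) vocabulary, so the source vocabulary `tokenizer.encoder` is inverted by hand."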
]
},
{
"cell_type": "code",
"metadata": {
"id": "xFhQOuGVXO2s"
},
"source": [
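"# invert the source (English) vocabulary: id -> token\r\n",
"en_vocab = {v: k for k, v in tokenizer.encoder.items()}\r\n",
"def detokenize(ids):\r\n",
"    # map each source id back to its BPE token, falling back to the unknown token\r\n",
"    return [en_vocab.get(int(idx), tokenizer.unk_token) for idx in ids]"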
],
"execution_count": 9,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "TX7zskroXWcM"
},
"source": [
"Show the attention as a table"
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 462
},
"id": "23a_tFAhXUkh",
"outputId": "f6bb83e1-dc1d-443f-c4b0-396b2eea3a7e"
},
"source": [
"# rows: target (Russian) tokens fed to the decoder; columns: source (English) tokens\r\n",
"df = pd.DataFrame(\r\n",
" mean_attention, \r\n",
" index=tokenizer.convert_ids_to_tokens(outputs[0]),\r\n",
" columns=detokenize(input_ids[0].numpy())\r\n",
")\r\n",
"df.round(2)"
],
"execution_count": 10,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Mach</th>\n",
" <th>ine&lt;/w&gt;</th>\n",
" <th>learning&lt;/w&gt;</th>\n",
" <th>is&lt;/w&gt;</th>\n",
" <th>great&lt;/w&gt;</th>\n",
" <th>,&lt;/w&gt;</th>\n",
" <th>isn&lt;/w&gt;</th>\n",
" <th>&amp;apos;t&lt;/w&gt;</th>\n",
" <th>it&lt;/w&gt;</th>\n",
" <th>?&lt;/w&gt;</th>\n",
" <th>&lt;/s&gt;</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>&lt;/s&gt;</th>\n",
" <td>0.22</td>\n",
" <td>0.10</td>\n",
" <td>0.16</td>\n",
" <td>0.09</td>\n",
" <td>0.10</td>\n",
" <td>0.09</td>\n",
" <td>0.04</td>\n",
" <td>0.04</td>\n",
" <td>0.01</td>\n",
" <td>0.15</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Ма</th>\n",
" <td>0.54</td>\n",
" <td>0.18</td>\n",
" <td>0.16</td>\n",
" <td>0.03</td>\n",
" <td>0.02</td>\n",
" <td>0.01</td>\n",
" <td>0.01</td>\n",
" <td>0.00</td>\n",
" <td>0.00</td>\n",
" <td>0.03</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>шин</th>\n",
" <td>0.33</td>\n",
" <td>0.19</td>\n",
" <td>0.34</td>\n",
" <td>0.04</td>\n",
" <td>0.04</td>\n",
" <td>0.01</td>\n",
" <td>0.01</td>\n",
" <td>0.01</td>\n",
" <td>0.01</td>\n",
" <td>0.03</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>ное&lt;/w&gt;</th>\n",
" <td>0.24</td>\n",
" <td>0.14</td>\n",
" <td>0.44</td>\n",
" <td>0.05</td>\n",
" <td>0.04</td>\n",
" <td>0.02</td>\n",
" <td>0.01</td>\n",
" <td>0.01</td>\n",
" <td>0.01</td>\n",
" <td>0.04</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>обучение&lt;/w&gt;</th>\n",
" <td>0.04</td>\n",
" <td>0.03</td>\n",
" <td>0.25</td>\n",
" <td>0.24</td>\n",
" <td>0.28</td>\n",
" <td>0.04</td>\n",
" <td>0.03</td>\n",
" <td>0.02</td>\n",
" <td>0.02</td>\n",
" <td>0.05</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>-&lt;/w&gt;</th>\n",
" <td>0.02</td>\n",
" <td>0.01</td>\n",
" <td>0.06</td>\n",
" <td>0.35</td>\n",
" <td>0.33</td>\n",
" <td>0.06</td>\n",
" <td>0.03</td>\n",
" <td>0.03</td>\n",
" <td>0.02</td>\n",
" <td>0.09</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>это&lt;/w&gt;</th>\n",
" <td>0.01</td>\n",
" <td>0.01</td>\n",
" <td>0.04</td>\n",
" <td>0.29</td>\n",
" <td>0.40</td>\n",
" <td>0.06</td>\n",
" <td>0.04</td>\n",
" <td>0.04</td>\n",
" <td>0.04</td>\n",
" <td>0.07</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>здорово&lt;/w&gt;</th>\n",
" <td>0.02</td>\n",
" <td>0.01</td>\n",
" <td>0.03</td>\n",
" <td>0.12</td>\n",
" <td>0.31</td>\n",
" <td>0.17</td>\n",
" <td>0.09</td>\n",
" <td>0.06</td>\n",
" <td>0.04</td>\n",
" <td>0.14</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>,&lt;/w&gt;</th>\n",
" <td>0.01</td>\n",
" <td>0.01</td>\n",
" <td>0.02</td>\n",
" <td>0.03</td>\n",
" <td>0.06</td>\n",
" <td>0.25</td>\n",
" <td>0.28</td>\n",
" <td>0.19</td>\n",
" <td>0.04</td>\n",
" <td>0.12</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>не&lt;/w&gt;</th>\n",
" <td>0.01</td>\n",
" <td>0.01</td>\n",
" <td>0.01</td>\n",
" <td>0.03</td>\n",
" <td>0.05</td>\n",
" <td>0.04</td>\n",
" <td>0.35</td>\n",
" <td>0.28</td>\n",
" <td>0.08</td>\n",
" <td>0.12</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>так&lt;/w&gt;</th>\n",
" <td>0.02</td>\n",
" <td>0.01</td>\n",
" <td>0.02</td>\n",
" <td>0.02</td>\n",
" <td>0.04</td>\n",
" <td>0.05</td>\n",
" <td>0.27</td>\n",
" <td>0.22</td>\n",
" <td>0.14</td>\n",
" <td>0.21</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>ли&lt;/w&gt;</th>\n",
" <td>0.03</td>\n",
" <td>0.01</td>\n",
" <td>0.04</td>\n",
" <td>0.04</td>\n",
" <td>0.05</td>\n",
" <td>0.04</td>\n",
" <td>0.21</td>\n",
" <td>0.11</td>\n",
" <td>0.14</td>\n",
" <td>0.33</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>?&lt;/w&gt;</th>\n",
" <td>0.04</td>\n",
" <td>0.03</td>\n",
" <td>0.04</td>\n",
" <td>0.03</td>\n",
" <td>0.04</td>\n",
" <td>0.05</td>\n",
" <td>0.06</td>\n",
" <td>0.05</td>\n",
" <td>0.05</td>\n",
" <td>0.62</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>&lt;/s&gt;</th>\n",
" <td>0.11</td>\n",
" <td>0.06</td>\n",
" <td>0.10</td>\n",
" <td>0.04</td>\n",
" <td>0.07</td>\n",
" <td>0.04</td>\n",
" <td>0.06</td>\n",
" <td>0.05</td>\n",
" <td>0.08</td>\n",
" <td>0.39</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Mach ine</w> learning</w> ... it</w> ?</w> </s>\n",
"</s> 0.22 0.10 0.16 ... 0.01 0.15 0.0\n",
"Ма 0.54 0.18 0.16 ... 0.00 0.03 0.0\n",
"шин 0.33 0.19 0.34 ... 0.01 0.03 0.0\n",
"ное</w> 0.24 0.14 0.44 ... 0.01 0.04 0.0\n",
"обучение</w> 0.04 0.03 0.25 ... 0.02 0.05 0.0\n",
"-</w> 0.02 0.01 0.06 ... 0.02 0.09 0.0\n",
"это</w> 0.01 0.01 0.04 ... 0.04 0.07 0.0\n",
"здорово</w> 0.02 0.01 0.03 ... 0.04 0.14 0.0\n",
",</w> 0.01 0.01 0.02 ... 0.04 0.12 0.0\n",
"не</w> 0.01 0.01 0.01 ... 0.08 0.12 0.0\n",
"так</w> 0.02 0.01 0.02 ... 0.14 0.21 0.0\n",
"ли</w> 0.03 0.01 0.04 ... 0.14 0.33 0.0\n",
"?</w> 0.04 0.03 0.04 ... 0.05 0.62 0.0\n",
"</s> 0.11 0.06 0.10 ... 0.08 0.39 0.0\n",
"\n",
"[14 rows x 11 columns]"
]
},
"metadata": {
"tags": []
},
"execution_count": 10
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "W2d6FGhZXYmG"
},
"source": [
"Bonus: alternative translations"
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "TbUfsKClbMAA",
"outputId": "37adc23f-4530-4636-a428-e1ead366fcbf"
},
"source": [
"with torch.no_grad():\r\n",
" outputs = model.generate(input_ids, num_beams=5, num_return_sequences=5)\r\n",
"for o in outputs:\r\n",
" print(tokenizer.decode(o, skip_special_tokens=True))"
],
"execution_count": 12,
"outputs": [
{
"output_type": "stream",
"text": [
"Машинное обучение - это здорово, не так ли?\n",
"Машинное обучение - это прекрасно, не так ли?\n",
"Машинное обучение - это здорово, не правда ли?\n",
"Машинное обучение - это прекрасно, не правда ли?\n",
"Машинное обучение - это замечательно, не так ли?\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "A76IsbTqbWHP",
"outputId": "42b5a6fd-b041-4488-808d-760789fe8789"
},
"source": [
"with torch.no_grad():\r\n",
" outputs = model.generate(input_ids, do_sample=True, num_return_sequences=5, temperature=2.0, max_length=20)\r\n",
"for o in outputs:\r\n",
" print(tokenizer.decode(o, skip_special_tokens=True))"
],
"execution_count": 19,
"outputs": [
{
"output_type": "stream",
"text": [
"Машинное обучение - замечательная штука, не так ли?\n",
"Машинное обучение это здорово, не так ли?\n",
"Машинное усвоение является отличным способом, не так ли?\n",
"Машинное познание - это прекрасно, не так ли?\n",
"Машинное обучение - большое дело, не так ли?\n"
],
"name": "stdout"
}
]
}
]
}
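A note on reading the attention table in Translation_Attention.ipynb: each row is labeled with the decoder input token at that position. Because the generated ids are teacher-forced through the decoder, the state at position t has seen the tokens up to t and is used to predict token t + 1, so it can be more natural to label row t with the next token. For example, the row labeled `это</w>` puts its largest weight (0.40) on `great</w>`, which matches the next token `здорово</w>`. The sketch below is pure Python; the function name `align_predicted_tokens` is hypothetical and not part of transformers. It applies that shift and picks the most-attended source token for each predicted token, which is a crude argmax alignment rather than a proper word aligner.

```python
def align_predicted_tokens(attention, target_tokens, source_tokens):
    """Pair each predicted target token with its most-attended source token.

    attention[t][s] is the averaged cross-attention from decoder position t
    to source position s; target_tokens are the decoder input tokens (starting
    with the decoder start token). Row t is used to predict target_tokens[t + 1],
    so rows are paired with the next token; zip drops the last row, whose
    prediction lies beyond the final token.
    """
    pairs = []
    for row, predicted in zip(attention, target_tokens[1:]):
        best = max(range(len(row)), key=lambda s: row[s])  # argmax over source positions
        pairs.append((predicted, source_tokens[best]))
    return pairs
```

In the notebook it could be called as `align_predicted_tokens(mean_attention, tokenizer.convert_ids_to_tokens(outputs[0]), detokenize(input_ids[0].numpy()))`, run before the bonus cells, since those reassign `outputs`.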
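The notebook's goal is word-level, but its table is indexed by BPE subword tokens (`Mach`, `ine</w>`; `Ма`, `шин`, `ное</w>`), where `</w>` marks the end of a word. A minimal sketch for collapsing the matrix to words, assuming that `</w>` convention and a hand-picked set of special tokens: the attention to a source word is the sum over its subword columns, and a target word's row is the mean of its subword rows, so each row still sums to 1 if the input rows did. The helper names `group_subwords` and `word_level_attention` are hypothetical.

```python
# Hypothetical helpers: collapse a subword-level attention matrix to word level.
SPECIAL_TOKENS = {"</s>", "<s>", "<pad>", "<unk>"}  # assumed special-token set
END_OF_WORD = "</w>"  # BPE end-of-word marker seen in the notebook's tokens


def group_subwords(tokens, special=SPECIAL_TOKENS):
    """Return (words, groups): groups[i] lists the token positions forming words[i]."""
    words, groups = [], []
    piece, members = "", []
    for i, tok in enumerate(tokens):
        if tok in special:
            if members:  # flush an unfinished word before the special token
                words.append(piece)
                groups.append(members)
                piece, members = "", []
            words.append(tok)  # special tokens form their own group
            groups.append([i])
            continue
        members.append(i)
        if tok.endswith(END_OF_WORD):
            words.append(piece + tok[: -len(END_OF_WORD)])
            groups.append(members)
            piece, members = "", []
        else:
            piece += tok  # the word continues in the next token
    if members:  # trailing word without an end-of-word marker
        words.append(piece)
        groups.append(members)
    return words, groups


def word_level_attention(attention, target_tokens, source_tokens):
    """Sum attention over each source word's subword columns and average it
    over each target word's subword rows."""
    tgt_words, tgt_groups = group_subwords(target_tokens)
    src_words, src_groups = group_subwords(source_tokens)
    matrix = [
        [sum(attention[r][c] for r in rows for c in cols) / len(rows) for cols in src_groups]
        for rows in tgt_groups
    ]
    return tgt_words, src_words, matrix
```

With the notebook's variables this would be `word_level_attention(mean_attention, tokenizer.convert_ids_to_tokens(outputs[0]), detokenize(input_ids[0].numpy()))`, and the result can be wrapped in a `pd.DataFrame` the same way as the original table.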
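In the last bonus cell, `do_sample=True` with `temperature=2.0` samples each next token from a softmax over the logits divided by the temperature, which is what the transformers temperature setting does during sampling. A value above 1 flattens the distribution, which is consistent with that cell producing less common variants such as `Машинное усвоение` next to the beam-search outputs. A toy sketch of the effect in pure Python; the function name is hypothetical and the scores are illustrative, not real model logits:

```python
import math


def softmax_with_temperature(logits, temperature=1.0):
    """Turn raw scores into probabilities after dividing them by `temperature`.

    temperature > 1 flattens the distribution, temperature < 1 sharpens it.
    """
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


logits = [4.0, 2.0, 0.0]  # toy scores for three candidate next tokens
for t in (1.0, 2.0):
    probs = softmax_with_temperature(logits, t)
    print(t, [round(p, 3) for p in probs])
```

At temperature 2.0 the top candidate loses probability mass to the others, so sampling picks lower-ranked tokens more often than at temperature 1.0.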