Skip to content

Instantly share code, notes, and snippets.

@MachineLearningIsEasy
Created October 12, 2020 11:05
Show Gist options
  • Save MachineLearningIsEasy/8ec7c3a100426411e0c04fea5f332beb to your computer and use it in GitHub Desktop.
Save MachineLearningIsEasy/8ec7c3a100426411e0c04fea5f332beb to your computer and use it in GitHub Desktop.
Create and train seq2seq model with pytorch
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.2"
},
"colab": {
"name": "Sequence to Sequence Learning with Neural Networks.ipynb",
"provenance": [],
"collapsed_sections": [],
"toc_visible": true
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "pIyzObvlX1_2"
},
"source": [
"# Seq2seq"
]
},
{
"cell_type": "code",
"metadata": {
"id": "aWtInJodCzi1"
},
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"\n",
"from torchtext.datasets import Multi30k\n",
"from torchtext.data import Field, BucketIterator\n",
"\n",
"import spacy\n",
"import numpy as np\n",
"\n",
"import random\n",
"import math\n",
"import time"
],
"execution_count": 1,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "u2HMKhWfCzi_"
},
"source": [
"Зафиксируем \"случайности\""
]
},
{
"cell_type": "code",
"metadata": {
"id": "l-U4opWUCzjA"
},
"source": [
"# Seed every random number generator used below so that runs are repeatable\n",
"SEED = 42\n",
"\n",
"for seed_everything in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):\n",
"    seed_everything(SEED)\n",
"\n",
"# Ask cuDNN to pick deterministic algorithms\n",
"torch.backends.cudnn.deterministic = True"
],
"execution_count": 33,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "Z4AYyH6wCzjH"
},
"source": [
"Выполним токенизацию:\n",
"\n",
"\"good morning!\" --> [\"good\", \"morning\", \"!\"]"
]
},
{
"cell_type": "code",
"metadata": {
"id": "Qg-239m1ZpNx",
"outputId": "c87ed451-b366-428b-f2b7-0e40e8074409",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 496
}
},
"source": [
"!python -m spacy download en"
],
"execution_count": 34,
"outputs": [
{
"output_type": "stream",
"text": [
"Requirement already satisfied: en_core_web_sm==2.2.5 from https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-2.2.5/en_core_web_sm-2.2.5.tar.gz#egg=en_core_web_sm==2.2.5 in /usr/local/lib/python3.6/dist-packages (2.2.5)\n",
"Requirement already satisfied: spacy>=2.2.2 in /usr/local/lib/python3.6/dist-packages (from en_core_web_sm==2.2.5) (2.2.4)\n",
"Requirement already satisfied: blis<0.5.0,>=0.4.0 in /usr/local/lib/python3.6/dist-packages (from spacy>=2.2.2->en_core_web_sm==2.2.5) (0.4.1)\n",
"Requirement already satisfied: preshed<3.1.0,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from spacy>=2.2.2->en_core_web_sm==2.2.5) (3.0.2)\n",
"Requirement already satisfied: thinc==7.4.0 in /usr/local/lib/python3.6/dist-packages (from spacy>=2.2.2->en_core_web_sm==2.2.5) (7.4.0)\n",
"Requirement already satisfied: murmurhash<1.1.0,>=0.28.0 in /usr/local/lib/python3.6/dist-packages (from spacy>=2.2.2->en_core_web_sm==2.2.5) (1.0.2)\n",
"Requirement already satisfied: srsly<1.1.0,>=1.0.2 in /usr/local/lib/python3.6/dist-packages (from spacy>=2.2.2->en_core_web_sm==2.2.5) (1.0.2)\n",
"Requirement already satisfied: requests<3.0.0,>=2.13.0 in /usr/local/lib/python3.6/dist-packages (from spacy>=2.2.2->en_core_web_sm==2.2.5) (2.23.0)\n",
"Requirement already satisfied: plac<1.2.0,>=0.9.6 in /usr/local/lib/python3.6/dist-packages (from spacy>=2.2.2->en_core_web_sm==2.2.5) (1.1.3)\n",
"Requirement already satisfied: setuptools in /usr/local/lib/python3.6/dist-packages (from spacy>=2.2.2->en_core_web_sm==2.2.5) (50.3.0)\n",
"Requirement already satisfied: wasabi<1.1.0,>=0.4.0 in /usr/local/lib/python3.6/dist-packages (from spacy>=2.2.2->en_core_web_sm==2.2.5) (0.8.0)\n",
"Requirement already satisfied: numpy>=1.15.0 in /usr/local/lib/python3.6/dist-packages (from spacy>=2.2.2->en_core_web_sm==2.2.5) (1.18.5)\n",
"Requirement already satisfied: cymem<2.1.0,>=2.0.2 in /usr/local/lib/python3.6/dist-packages (from spacy>=2.2.2->en_core_web_sm==2.2.5) (2.0.3)\n",
"Requirement already satisfied: tqdm<5.0.0,>=4.38.0 in /usr/local/lib/python3.6/dist-packages (from spacy>=2.2.2->en_core_web_sm==2.2.5) (4.41.1)\n",
"Requirement already satisfied: catalogue<1.1.0,>=0.0.7 in /usr/local/lib/python3.6/dist-packages (from spacy>=2.2.2->en_core_web_sm==2.2.5) (1.0.0)\n",
"Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests<3.0.0,>=2.13.0->spacy>=2.2.2->en_core_web_sm==2.2.5) (1.24.3)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests<3.0.0,>=2.13.0->spacy>=2.2.2->en_core_web_sm==2.2.5) (2020.6.20)\n",
"Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests<3.0.0,>=2.13.0->spacy>=2.2.2->en_core_web_sm==2.2.5) (2.10)\n",
"Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests<3.0.0,>=2.13.0->spacy>=2.2.2->en_core_web_sm==2.2.5) (3.0.4)\n",
"Requirement already satisfied: importlib-metadata>=0.20; python_version < \"3.8\" in /usr/local/lib/python3.6/dist-packages (from catalogue<1.1.0,>=0.0.7->spacy>=2.2.2->en_core_web_sm==2.2.5) (2.0.0)\n",
"Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.6/dist-packages (from importlib-metadata>=0.20; python_version < \"3.8\"->catalogue<1.1.0,>=0.0.7->spacy>=2.2.2->en_core_web_sm==2.2.5) (3.2.0)\n",
"\u001b[38;5;2m✔ Download and installation successful\u001b[0m\n",
"You can now load the model via spacy.load('en_core_web_sm')\n",
"\u001b[38;5;2m✔ Linking successful\u001b[0m\n",
"/usr/local/lib/python3.6/dist-packages/en_core_web_sm -->\n",
"/usr/local/lib/python3.6/dist-packages/spacy/data/en\n",
"You can now load the model via spacy.load('en')\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "aB-KBLojZwLc",
"outputId": "dd64f00b-2d9f-40e5-fa30-4babc2a5ed06",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 496
}
},
"source": [
"!python -m spacy download de"
],
"execution_count": 35,
"outputs": [
{
"output_type": "stream",
"text": [
"Requirement already satisfied: de_core_news_sm==2.2.5 from https://github.com/explosion/spacy-models/releases/download/de_core_news_sm-2.2.5/de_core_news_sm-2.2.5.tar.gz#egg=de_core_news_sm==2.2.5 in /usr/local/lib/python3.6/dist-packages (2.2.5)\n",
"Requirement already satisfied: spacy>=2.2.2 in /usr/local/lib/python3.6/dist-packages (from de_core_news_sm==2.2.5) (2.2.4)\n",
"Requirement already satisfied: plac<1.2.0,>=0.9.6 in /usr/local/lib/python3.6/dist-packages (from spacy>=2.2.2->de_core_news_sm==2.2.5) (1.1.3)\n",
"Requirement already satisfied: tqdm<5.0.0,>=4.38.0 in /usr/local/lib/python3.6/dist-packages (from spacy>=2.2.2->de_core_news_sm==2.2.5) (4.41.1)\n",
"Requirement already satisfied: blis<0.5.0,>=0.4.0 in /usr/local/lib/python3.6/dist-packages (from spacy>=2.2.2->de_core_news_sm==2.2.5) (0.4.1)\n",
"Requirement already satisfied: preshed<3.1.0,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from spacy>=2.2.2->de_core_news_sm==2.2.5) (3.0.2)\n",
"Requirement already satisfied: cymem<2.1.0,>=2.0.2 in /usr/local/lib/python3.6/dist-packages (from spacy>=2.2.2->de_core_news_sm==2.2.5) (2.0.3)\n",
"Requirement already satisfied: srsly<1.1.0,>=1.0.2 in /usr/local/lib/python3.6/dist-packages (from spacy>=2.2.2->de_core_news_sm==2.2.5) (1.0.2)\n",
"Requirement already satisfied: thinc==7.4.0 in /usr/local/lib/python3.6/dist-packages (from spacy>=2.2.2->de_core_news_sm==2.2.5) (7.4.0)\n",
"Requirement already satisfied: murmurhash<1.1.0,>=0.28.0 in /usr/local/lib/python3.6/dist-packages (from spacy>=2.2.2->de_core_news_sm==2.2.5) (1.0.2)\n",
"Requirement already satisfied: catalogue<1.1.0,>=0.0.7 in /usr/local/lib/python3.6/dist-packages (from spacy>=2.2.2->de_core_news_sm==2.2.5) (1.0.0)\n",
"Requirement already satisfied: setuptools in /usr/local/lib/python3.6/dist-packages (from spacy>=2.2.2->de_core_news_sm==2.2.5) (50.3.0)\n",
"Requirement already satisfied: numpy>=1.15.0 in /usr/local/lib/python3.6/dist-packages (from spacy>=2.2.2->de_core_news_sm==2.2.5) (1.18.5)\n",
"Requirement already satisfied: wasabi<1.1.0,>=0.4.0 in /usr/local/lib/python3.6/dist-packages (from spacy>=2.2.2->de_core_news_sm==2.2.5) (0.8.0)\n",
"Requirement already satisfied: requests<3.0.0,>=2.13.0 in /usr/local/lib/python3.6/dist-packages (from spacy>=2.2.2->de_core_news_sm==2.2.5) (2.23.0)\n",
"Requirement already satisfied: importlib-metadata>=0.20; python_version < \"3.8\" in /usr/local/lib/python3.6/dist-packages (from catalogue<1.1.0,>=0.0.7->spacy>=2.2.2->de_core_news_sm==2.2.5) (2.0.0)\n",
"Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests<3.0.0,>=2.13.0->spacy>=2.2.2->de_core_news_sm==2.2.5) (1.24.3)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests<3.0.0,>=2.13.0->spacy>=2.2.2->de_core_news_sm==2.2.5) (2020.6.20)\n",
"Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests<3.0.0,>=2.13.0->spacy>=2.2.2->de_core_news_sm==2.2.5) (2.10)\n",
"Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests<3.0.0,>=2.13.0->spacy>=2.2.2->de_core_news_sm==2.2.5) (3.0.4)\n",
"Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.6/dist-packages (from importlib-metadata>=0.20; python_version < \"3.8\"->catalogue<1.1.0,>=0.0.7->spacy>=2.2.2->de_core_news_sm==2.2.5) (3.2.0)\n",
"\u001b[38;5;2m✔ Download and installation successful\u001b[0m\n",
"You can now load the model via spacy.load('de_core_news_sm')\n",
"\u001b[38;5;2m✔ Linking successful\u001b[0m\n",
"/usr/local/lib/python3.6/dist-packages/de_core_news_sm -->\n",
"/usr/local/lib/python3.6/dist-packages/spacy/data/de\n",
"You can now load the model via spacy.load('de')\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "TghQwcHMCzjI"
},
"source": [
"# Load the models by their package names (installed by the download cells above).\n",
"# Unlike the 'de'/'en' shortcut links, package names also work with spaCy 3.\n",
"spacy_de = spacy.load('de_core_news_sm')\n",
"spacy_en = spacy.load('en_core_web_sm')"
],
"execution_count": 36,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "bV6NIMQVCzjO"
},
"source": [
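"Напишем функции для токенизации. Немецкое (исходное) предложение при этом переворачиваем: авторы статьи \"Sequence to Sequence Learning with Neural Networks\" отмечают, что это облегчает обучение."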
]
},
{
"cell_type": "code",
"metadata": {
"id": "qCPY6hL_CzjQ"
},
"source": [
"def tokenize_de(text):\n",
"    \"\"\"\n",
"    Split a German string into spaCy token strings, returned in reverse order.\n",
"\n",
"    Reversing the source sentence is the trick from \"Sequence to Sequence\n",
"    Learning with Neural Networks\".\n",
"    \"\"\"\n",
"    tokens = [token.text for token in spacy_de.tokenizer(text)]\n",
"    tokens.reverse()\n",
"    return tokens\n",
"\n",
"def tokenize_en(text):\n",
"    \"\"\"\n",
"    Split an English string into spaCy token strings, keeping the original order.\n",
"    \"\"\"\n",
"    return [t.text for t in spacy_en.tokenizer(text)]"
],
"execution_count": 37,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "vydsrfS6CzjV"
},
"source": [
"Создаем объект класса Field [(см. здесь)](https://github.com/pytorch/text/blob/master/torchtext/data/field.py#L61) с настройками предобработки текстов\n"
]
},
{
"cell_type": "code",
"metadata": {
"id": "ic-64nZ1CzjW"
},
"source": [
"# Both languages get <sos>/<eos> markers and are lower-cased;\n",
"# only the tokenizer differs (tokenize_de also reverses the tokens).\n",
"SRC = Field(init_token = '<sos>',\n",
"            eos_token = '<eos>',\n",
"            lower = True,\n",
"            tokenize = tokenize_de)\n",
"\n",
"TRG = Field(init_token = '<sos>',\n",
"            eos_token = '<eos>',\n",
"            lower = True,\n",
"            tokenize = tokenize_en)"
],
"execution_count": 38,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "2yzJqqKtCzjc"
},
"source": [
"\n",
"Загружаем датасет [Multi30k](https://github.com/multi30k/dataset): около 30 тысяч параллельных предложений на английском, немецком и французском языках, средняя длина предложения — около 12 слов.\n",
"Датасет уже разбит на обучающую, валидационную и тестовую выборки; загружаем все три части (немецкий язык — исходный, английский — целевой)."
]
},
{
"cell_type": "code",
"metadata": {
"id": "xXxf75GfCzjd"
},
"source": [
"# Load the train / validation / test splits; '.de' files feed SRC, '.en' files feed TRG\n",
"train_data, valid_data, test_data = Multi30k.splits(\n",
"    exts = ('.de', '.en'),\n",
"    fields = (SRC, TRG),\n",
")"
],
"execution_count": 39,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "__7GZlypCzjn",
"outputId": "6969df8a-5d0d-45b6-d40f-f9edd35f79d3",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 68
}
},
"source": [
"# Number of sentence pairs in each split, one per line\n",
"print(f\"Датасет для обучения: {len(train_data.examples)}\",\n",
"      f\"Датасет для валидации: {len(valid_data.examples)}\",\n",
"      f\"Датасет для тестировки: {len(test_data.examples)}\",\n",
"      sep=\"\\n\")"
],
"execution_count": 40,
"outputs": [
{
"output_type": "stream",
"text": [
"Датасет для обучения: 29000\n",
"Датасет для валидации: 1014\n",
"Датасет для тестировки: 1000\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "3UNWmLtwCzjx",
"outputId": "487fdc70-d127-4c20-97cf-21f3bb966ca3",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 54
}
},
"source": [
"# First training pair: the German tokens are stored reversed, the English ones are not\n",
"print(train_data.examples[0].__dict__)"
],
"execution_count": 41,
"outputs": [
{
"output_type": "stream",
"text": [
"{'src': ['.', 'büsche', 'vieler', 'nähe', 'der', 'in', 'freien', 'im', 'sind', 'männer', 'weiße', 'junge', 'zwei'], 'trg': ['two', 'young', ',', 'white', 'males', 'are', 'outside', 'near', 'many', 'bushes', '.']}\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "8B7X2NdnCzj2"
},
"source": [
"Сформируем словари"
]
},
{
"cell_type": "code",
"metadata": {
"id": "vE8KXMvNCzj3"
},
"source": [
"# Vocabularies are built from the training split only; tokens that occur fewer\n",
"# than MIN_FREQ times there are not added to the vocabulary.\n",
"MIN_FREQ = 2\n",
"\n",
"SRC.build_vocab(train_data, min_freq = MIN_FREQ)\n",
"TRG.build_vocab(train_data, min_freq = MIN_FREQ)"
],
"execution_count": 42,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "Zvlyw0BaCzkB",
"outputId": "bfdab471-6cfa-49f2-a9ed-681caca4b963",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 51
}
},
"source": [
"# Vocabulary sizes of the source (SRC, German) and target (TRG, English) fields,\n",
"# labelled so the two lines can be told apart\n",
"print(f\"Словарь исходного языка (немецкий) состоит из {len(SRC.vocab)} слов\")\n",
"print(f\"Словарь целевого языка (английский) состоит из {len(TRG.vocab)} слов\")"
],
"execution_count": 43,
"outputs": [
{
"output_type": "stream",
"text": [
"Словарь состоит из 7855 слов\n",
"Словарь состоит из 5893 слов\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "v38NFsiwCzkF"
},
"source": [
"Создадим итератор для формирования батчей для подачи данных в сеть"
]
},
{
"cell_type": "code",
"metadata": {
"id": "mgqSpOh8CzkG"
},
"source": [
"# Train on the GPU when one is available, otherwise fall back to the CPU\n",
"device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')"
],
"execution_count": 44,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "pGEjWL8NWh6e",
"outputId": "626611d0-12b1-4873-d651-d01552edf8f6",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
}
},
"source": [
"device"
],
"execution_count": 45,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"device(type='cuda')"
]
},
"metadata": {
"tags": []
},
"execution_count": 45
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "UPVPkOt8CzkK"
},
"source": [
"BATCH_SIZE = 128\n",
"\n",
"# BucketIterator (per the torchtext docs) batches examples of similar length\n",
"# together to reduce padding; batches are placed on `device`.\n",
"train_iterator, valid_iterator, test_iterator = BucketIterator.splits(\n",
"    (train_data, valid_data, test_data),\n",
"    device = device,\n",
"    batch_size = BATCH_SIZE)"
],
"execution_count": 46,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "3sYkcJb_CzkQ"
},
"source": [
"## Строим классическую модель Seq2Seq\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "nltoT1LwycZZ"
},
"source": [
"### Encoder"
]
},
{
"cell_type": "code",
"metadata": {
"id": "PEc3llCECzkQ"
},
"source": [
"class Encoder(nn.Module):\n",
"    \"\"\"\n",
"    Embeds a source sequence and runs it through a multi-layer LSTM.\n",
"\n",
"    Only the final hidden and cell states are returned; the Seq2Seq wrapper\n",
"    uses them as the initial state of the decoder.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, input_dim, emb_dim, hid_dim, n_layers, dropout):\n",
"        super().__init__()\n",
"\n",
"        self.hid_dim = hid_dim\n",
"        self.n_layers = n_layers\n",
"\n",
"        self.embedding = nn.Embedding(input_dim, emb_dim)\n",
"        # The LSTM's own `dropout` acts between its stacked layers\n",
"        self.rnn = nn.LSTM(emb_dim, hid_dim, n_layers, dropout = dropout)\n",
"        self.dropout = nn.Dropout(dropout)\n",
"\n",
"    def forward(self, src):\n",
"        \"\"\"\n",
"        src: token ids, [src len, batch size]\n",
"        returns: (hidden, cell), each [n layers, batch size, hid dim]\n",
"        (the LSTM is unidirectional, so n directions is 1)\n",
"        \"\"\"\n",
"        embedded = self.embedding(src)\n",
"        embedded = self.dropout(embedded)  # [src len, batch size, emb dim]\n",
"\n",
"        # Per-step outputs of the top layer are not needed by this model\n",
"        _, (hidden, cell) = self.rnn(embedded)\n",
"\n",
"        return hidden, cell"
],
"execution_count": 47,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "fpxMfzHBCzkW"
},
"source": [
"### Decoder\n",
"\n"
]
},
{
"cell_type": "code",
"metadata": {
"id": "ggXPsdatCzkX"
},
"source": [
"class Decoder(nn.Module):\n",
"    \"\"\"\n",
"    Single-step LSTM decoder: from the previous target token and the previous\n",
"    (hidden, cell) state it produces scores over the target vocabulary.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, output_dim, emb_dim, hid_dim, n_layers, dropout):\n",
"        super().__init__()\n",
"\n",
"        self.output_dim = output_dim\n",
"        self.hid_dim = hid_dim\n",
"        self.n_layers = n_layers\n",
"\n",
"        self.embedding = nn.Embedding(output_dim, emb_dim)\n",
"        self.rnn = nn.LSTM(emb_dim, hid_dim, n_layers, dropout = dropout)\n",
"        self.fc_out = nn.Linear(hid_dim, output_dim)\n",
"        self.dropout = nn.Dropout(dropout)\n",
"\n",
"    def forward(self, input, hidden, cell):\n",
"        \"\"\"\n",
"        input:  previous token ids, [batch size]\n",
"        hidden: [n layers, batch size, hid dim]\n",
"        cell:   [n layers, batch size, hid dim]\n",
"        returns (prediction [batch size, output dim], hidden, cell)\n",
"        \"\"\"\n",
"        # Treat the batch of tokens as a sequence of length one: [1, batch size]\n",
"        step = input.unsqueeze(0)\n",
"        embedded = self.dropout(self.embedding(step))  # [1, batch size, emb dim]\n",
"\n",
"        output, (hidden, cell) = self.rnn(embedded, (hidden, cell))\n",
"\n",
"        # Drop the length-one time axis before projecting onto the vocabulary\n",
"        prediction = self.fc_out(output.squeeze(0))\n",
"        return prediction, hidden, cell"
],
"execution_count": 48,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "7Q0fyB5YCzkb"
},
"source": [
"### Seq2Seq\n",
"\n",
"\n",
"Пример входных/выходных данных\n",
"\n",
"$$\\begin{align*}\n",
"\\text{trg} = [<sos>, &y_1, y_2, y_3, <eos>]\\\\\n",
"\\text{outputs} = [0, &\\hat{y}_1, \\hat{y}_2, \\hat{y}_3, <eos>]\n",
"\\end{align*}$$"
]
},
{
"cell_type": "code",
"metadata": {
"id": "2SWW5MbfCzke"
},
"source": [
"class Seq2Seq(nn.Module):\n",
"    \"\"\"\n",
"    Encoder-decoder wrapper that unrolls the decoder over the target sequence.\n",
"\n",
"    The encoder's final (hidden, cell) state initialises the decoder, so both\n",
"    must use the same hidden size and the same number of layers.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, encoder, decoder, device):\n",
"        super().__init__()\n",
"\n",
"        self.encoder = encoder\n",
"        self.decoder = decoder\n",
"        self.device = device\n",
"\n",
"        assert encoder.hid_dim == decoder.hid_dim, \\\n",
"            \"Hidden dimensions of encoder and decoder must be equal!\"\n",
"        assert encoder.n_layers == decoder.n_layers, \\\n",
"            \"Encoder and decoder must have equal number of layers!\"\n",
"\n",
"    def forward(self, src, trg, teacher_forcing_ratio = 0.5):\n",
"        \"\"\"\n",
"        src: [src len, batch size] source token ids\n",
"        trg: [trg len, batch size] target token ids, starting with <sos>\n",
"        teacher_forcing_ratio: probability, drawn independently at every step,\n",
"            of feeding the ground-truth token instead of the model's own\n",
"            prediction as the next decoder input (0 disables teacher forcing)\n",
"\n",
"        Returns [trg len, batch size, trg vocab size] scores; row 0 stays all\n",
"        zeros because nothing is predicted for the <sos> position.\n",
"        \"\"\"\n",
"        trg_len, batch_size = trg.shape[0], trg.shape[1]\n",
"        trg_vocab_size = self.decoder.output_dim\n",
"\n",
"        # Allocate directly on the target device instead of creating the\n",
"        # tensor on the CPU and copying it over on every forward pass.\n",
"        outputs = torch.zeros(trg_len, batch_size, trg_vocab_size, device = self.device)\n",
"\n",
"        # The encoder's last states initialise the decoder\n",
"        hidden, cell = self.encoder(src)\n",
"\n",
"        # The first decoder input is the <sos> row of the target batch\n",
"        dec_input = trg[0, :]\n",
"\n",
"        for t in range(1, trg_len):\n",
"            output, hidden, cell = self.decoder(dec_input, hidden, cell)\n",
"            outputs[t] = output\n",
"\n",
"            teacher_force = random.random() < teacher_forcing_ratio\n",
"            top1 = output.argmax(1)\n",
"\n",
"            # Ground-truth token when teacher forcing, otherwise the greedy prediction\n",
"            dec_input = trg[t] if teacher_force else top1\n",
"\n",
"        return outputs"
],
"execution_count": 49,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "2vjqx2kQCzkl"
},
"source": [
"## Обучаем модель Seq2Seq\n"
]
},
{
"cell_type": "code",
"metadata": {
"id": "WsKgtYmgCzkl"
},
"source": [
"# Model hyper-parameters\n",
"INPUT_DIM = len(SRC.vocab)    # source (German) vocabulary size\n",
"OUTPUT_DIM = len(TRG.vocab)   # target (English) vocabulary size\n",
"ENC_EMB_DIM = DEC_EMB_DIM = 256\n",
"HID_DIM = 512\n",
"N_LAYERS = 2\n",
"ENC_DROPOUT = DEC_DROPOUT = 0.5\n",
"\n",
"enc = Encoder(INPUT_DIM, ENC_EMB_DIM, HID_DIM, N_LAYERS, ENC_DROPOUT)\n",
"dec = Decoder(OUTPUT_DIM, DEC_EMB_DIM, HID_DIM, N_LAYERS, DEC_DROPOUT)\n",
"\n",
"model = Seq2Seq(enc, dec, device).to(device)"
],
"execution_count": 50,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "S6B4cLuACzkt"
},
"source": [
"Инициализируем все веса модели значениями из равномерного распределения $\\mathcal{U}(-0.08, 0.08)$.\n"
]
},
{
"cell_type": "code",
"metadata": {
"id": "CjgY--uZCzku",
"outputId": "38bf4192-5fb9-4df6-f415-3826e2960e58",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 238
}
},
"source": [
"def init_weights(m, bound = 0.08):\n",
"    \"\"\"\n",
"    Initialise the parameters owned directly by module `m` from U(-bound, bound).\n",
"\n",
"    Meant for `model.apply(...)`, which already calls this once for every\n",
"    submodule. Only the module's own parameters are touched (recurse=False);\n",
"    iterating recursively here would re-initialise each parameter once per\n",
"    ancestor module.\n",
"    \"\"\"\n",
"    for param in m.parameters(recurse = False):\n",
"        nn.init.uniform_(param, -bound, bound)\n",
"\n",
"model.apply(init_weights)"
],
"execution_count": 51,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"Seq2Seq(\n",
" (encoder): Encoder(\n",
" (embedding): Embedding(7855, 256)\n",
" (rnn): LSTM(256, 512, num_layers=2, dropout=0.5)\n",
" (dropout): Dropout(p=0.5, inplace=False)\n",
" )\n",
" (decoder): Decoder(\n",
" (embedding): Embedding(5893, 256)\n",
" (rnn): LSTM(256, 512, num_layers=2, dropout=0.5)\n",
" (fc_out): Linear(in_features=512, out_features=5893, bias=True)\n",
" (dropout): Dropout(p=0.5, inplace=False)\n",
" )\n",
")"
]
},
"metadata": {
"tags": []
},
"execution_count": 51
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "d8yNFNT3Czkz"
},
"source": [
"Посчитаем количество обучаемых параметров сети"
]
},
{
"cell_type": "code",
"metadata": {
"id": "mjR0K_NTCzk2",
"outputId": "25fa8971-6e01-4f42-9e40-1be4d46ba841",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
}
},
"source": [
"def count_parameters(model):\n",
"    \"\"\"Return the total number of elements in all parameters that require gradients.\"\"\"\n",
"    total = 0\n",
"    for param in model.parameters():\n",
"        if param.requires_grad:\n",
"            total += param.numel()\n",
"    return total\n",
"\n",
"print(f'Модель содержит {count_parameters(model):,} параметров для обучения')"
],
"execution_count": 52,
"outputs": [
{
"output_type": "stream",
"text": [
"Модель содержит 13,899,013 параметров для обучения\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "NDBNKc0JCzk7"
},
"source": [
"Зададим оптимизатор Adam."
]
},
{
"cell_type": "code",
"metadata": {
"id": "qVK9y5QRCzk7"
},
"source": [
"# Adam with PyTorch's default hyper-parameters\n",
"optimizer = optim.Adam(params = model.parameters())"
],
"execution_count": 53,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "q1OxzrSeCzlB"
},
"source": [
"Зададим функцию потерь `CrossEntropyLoss`; позиции паддинга (`<pad>`) в целевых последовательностях при подсчете потерь игнорируем."
]
},
{
"cell_type": "code",
"metadata": {
"id": "gf-YcngWCzlC"
},
"source": [
"# Padding positions in the target must not contribute to the loss\n",
"TRG_PAD_IDX = TRG.vocab.stoi[TRG.pad_token]\n",
"criterion = nn.CrossEntropyLoss(ignore_index = TRG_PAD_IDX)"
],
"execution_count": 54,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZyvBWwrGCzlF"
},
"source": [
"Обучаем!\n",
"\n",
"На каждой итерации:\n",
"- формируем батч, $X$ и $Y$;\n",
"- обнуляем градиенты, накопленные на предыдущей итерации;\n",
"- подаем в сеть $X$;\n",
"- подгоняем размерности тензоров под функцию потерь;\n",
"- вычисляем значение функции потерь;\n",
"- вычисляем градиенты (обратное распространение ошибки);\n",
"- выполняем усечение (clipping) нормы градиента;\n",
"- обновляем веса нейронной сети;\n",
"- накапливаем значение функции потерь, чтобы вернуть среднее за эпоху.\n"
]
},
{
"cell_type": "code",
"metadata": {
"id": "cLJr7bHICzlG"
},
"source": [
"def train(model, iterator, optimizer, criterion, clip):\n",
"    \"\"\"\n",
"    Run one training epoch over `iterator` and return the mean per-batch loss.\n",
"\n",
"    The model is called as model(src, trg), i.e. with its default\n",
"    teacher-forcing ratio. `clip` is the maximum total gradient norm.\n",
"    \"\"\"\n",
"    model.train()\n",
"\n",
"    running_loss = 0\n",
"\n",
"    for batch in iterator:\n",
"        optimizer.zero_grad()\n",
"\n",
"        logits = model(batch.src, batch.trg)  # [trg len, batch size, output dim]\n",
"\n",
"        # Skip position 0 (<sos>) and flatten time and batch for the loss:\n",
"        # logits  -> [(trg len - 1) * batch size, output dim]\n",
"        # targets -> [(trg len - 1) * batch size]\n",
"        vocab_size = logits.shape[-1]\n",
"        loss = criterion(logits[1:].view(-1, vocab_size), batch.trg[1:].view(-1))\n",
"\n",
"        loss.backward()\n",
"        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)\n",
"        optimizer.step()\n",
"\n",
"        running_loss += loss.item()\n",
"\n",
"    return running_loss / len(iterator)"
],
"execution_count": 55,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "eQWtbvLQCzlL"
},
"source": [
"Функция валидации нейронной сети: веса не обновляются, teacher forcing отключен"
]
},
{
"cell_type": "code",
"metadata": {
"id": "sKt6n_SpCzlM"
},
"source": [
"def evaluate(model, iterator, criterion):\n",
"    \"\"\"\n",
"    Return the mean per-batch loss over `iterator` without updating weights.\n",
"\n",
"    Teacher forcing is disabled (ratio 0), so the decoder is fed its own\n",
"    predictions at every step.\n",
"    \"\"\"\n",
"    model.eval()\n",
"\n",
"    running_loss = 0\n",
"\n",
"    with torch.no_grad():\n",
"        for batch in iterator:\n",
"            logits = model(batch.src, batch.trg, 0)  # [trg len, batch size, output dim]\n",
"\n",
"            # Drop the <sos> position and flatten time x batch for the loss\n",
"            vocab_size = logits.shape[-1]\n",
"            loss = criterion(logits[1:].view(-1, vocab_size), batch.trg[1:].view(-1))\n",
"\n",
"            running_loss += loss.item()\n",
"\n",
"    return running_loss / len(iterator)"
],
"execution_count": 56,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "MqbxRfNKCzlQ"
},
"source": [
"Функция для измерения времени вычисления на одной эпохе"
]
},
{
"cell_type": "code",
"metadata": {
"id": "4fTIeDylCzlR"
},
"source": [
"def epoch_time(start_time, end_time):\n",
"    \"\"\"Split the span between two time.time() stamps into whole minutes and leftover whole seconds.\"\"\"\n",
"    elapsed = end_time - start_time\n",
"    whole_minutes = int(elapsed / 60)\n",
"    leftover_seconds = int(elapsed - 60 * whole_minutes)\n",
"    return whole_minutes, leftover_seconds"
],
"execution_count": 57,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "nIqluYfJCzlW"
},
"source": [
"Всё-таки обучаем!"
]
},
{
"cell_type": "code",
"metadata": {
"id": "rKw3yaq-CzlX",
"outputId": "a467a891-fd5f-41c3-9b78-64291c891556",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 527
}
},
"source": [
"N_EPOCHS = 10\n",
"CLIP = 1\n",
"\n",
"best_valid_loss = float('inf')\n",
"\n",
"for epoch in range(N_EPOCHS):\n",
"    start_time = time.time()\n",
"    train_loss = train(model, train_iterator, optimizer, criterion, CLIP)\n",
"    valid_loss = evaluate(model, valid_iterator, criterion)\n",
"    end_time = time.time()\n",
"\n",
"    epoch_mins, epoch_secs = epoch_time(start_time, end_time)\n",
"\n",
"    # Checkpoint the weights with the lowest validation loss seen so far\n",
"    if valid_loss < best_valid_loss:\n",
"        best_valid_loss = valid_loss\n",
"        torch.save(model.state_dict(), 'tut1-model.pt')\n",
"\n",
"    # Perplexity is reported as exp(mean loss)\n",
"    print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s')\n",
"    print(f'\\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}')\n",
"    print(f'\\t Val. Loss: {valid_loss:.3f} | Val. PPL: {math.exp(valid_loss):7.3f}')"
],
"execution_count": 58,
"outputs": [
{
"output_type": "stream",
"text": [
"Epoch: 01 | Time: 0m 30s\n",
"\tTrain Loss: 5.041 | Train PPL: 154.680\n",
"\t Val. Loss: 4.865 | Val. PPL: 129.692\n",
"Epoch: 02 | Time: 0m 30s\n",
"\tTrain Loss: 4.468 | Train PPL: 87.144\n",
"\t Val. Loss: 4.747 | Val. PPL: 115.236\n",
"Epoch: 03 | Time: 0m 30s\n",
"\tTrain Loss: 4.192 | Train PPL: 66.167\n",
"\t Val. Loss: 4.648 | Val. PPL: 104.383\n",
"Epoch: 04 | Time: 0m 30s\n",
"\tTrain Loss: 3.991 | Train PPL: 54.094\n",
"\t Val. Loss: 4.516 | Val. PPL: 91.478\n",
"Epoch: 05 | Time: 0m 30s\n",
"\tTrain Loss: 3.848 | Train PPL: 46.910\n",
"\t Val. Loss: 4.368 | Val. PPL: 78.879\n",
"Epoch: 06 | Time: 0m 30s\n",
"\tTrain Loss: 3.700 | Train PPL: 40.451\n",
"\t Val. Loss: 4.397 | Val. PPL: 81.209\n",
"Epoch: 07 | Time: 0m 30s\n",
"\tTrain Loss: 3.577 | Train PPL: 35.781\n",
"\t Val. Loss: 4.334 | Val. PPL: 76.212\n",
"Epoch: 08 | Time: 0m 30s\n",
"\tTrain Loss: 3.465 | Train PPL: 31.987\n",
"\t Val. Loss: 4.170 | Val. PPL: 64.689\n",
"Epoch: 09 | Time: 0m 30s\n",
"\tTrain Loss: 3.368 | Train PPL: 29.032\n",
"\t Val. Loss: 4.092 | Val. PPL: 59.874\n",
"Epoch: 10 | Time: 0m 30s\n",
"\tTrain Loss: 3.261 | Train PPL: 26.084\n",
"\t Val. Loss: 4.068 | Val. PPL: 58.427\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "M4bQ_DBNCzla"
},
"source": [
"\n",
"Посмотрим, что получилось: загрузим веса с наименьшей ошибкой на валидации и оценим модель на тестовой выборке"
]
},
{
"cell_type": "code",
"metadata": {
"id": "x6Y6T2zPCzlb",
"outputId": "7046ba98-6d31-4249-8d8c-780ced87ddeb",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
}
},
"source": [
"# Restore the checkpoint with the lowest validation loss, then score it on the test split\n",
"model.load_state_dict(torch.load('tut1-model.pt'))\n",
"test_loss = evaluate(model, test_iterator, criterion)\n",
"\n",
"print(f'| Test Loss: {test_loss:.3f} | Test PPL: {math.exp(test_loss):7.3f} |')"
],
"execution_count": 59,
"outputs": [
{
"output_type": "stream",
"text": [
"| Test Loss: 4.063 | Test PPL: 58.121 |\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "QQbY2CdcxRxC"
},
"source": [
"Выполним перевод предложения из тестовой выборки"
]
},
{
"cell_type": "code",
"metadata": {
"id": "_O5lFH8md4Tx",
"outputId": "d416977d-1e3c-4ed0-bc9d-e93f9787c1d0",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 51
}
},
"source": [
"example_idx = 18\n",
"\n",
"# One test pair; the German tokens are stored reversed by tokenize_de\n",
"src = test_data.examples[example_idx].src\n",
"trg = test_data.examples[example_idx].trg\n",
"\n",
"print(f'src = {src}', f'trg = {trg}', sep='\\n')"
],
"execution_count": 60,
"outputs": [
{
"output_type": "stream",
"text": [
"src = ['.', 'berg', 'einen', 'auf', 'klettert', 'shirt', 'gestreiften', 'im', 'person', 'die']\n",
"trg = ['the', 'person', 'in', 'the', 'striped', 'shirt', 'is', 'mountain', 'climbing', '.']\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "IdIXZnRFf1wz"
},
"source": [
"def translate_sentence(sentence, src_field, trg_field, model, device, max_len = 50):\n",
"    \"\"\"\n",
"    Greedily translate a single German sentence with a trained Seq2Seq model.\n",
"\n",
"    sentence  -- raw text (str), or a list of tokens already processed like the\n",
"                 training data (i.e. already reversed by the SRC tokenizer)\n",
"    src_field -- source-language Field (SRC); supplies tokenizer and vocabulary\n",
"    trg_field -- target-language Field (TRG); supplies vocabulary and special tokens\n",
"    model     -- Seq2Seq model whose encoder returns (hidden, cell) and whose\n",
"                 decoder returns (prediction, hidden, cell)\n",
"    device    -- device the input tensors are created on\n",
"    max_len   -- maximum number of target tokens to generate\n",
"\n",
"    Returns the predicted target tokens without the leading <sos>; the final\n",
"    <eos> is included when the model produced it within max_len steps.\n",
"    \"\"\"\n",
"    model.eval()\n",
"\n",
"    if isinstance(sentence, str):\n",
"        # Tokenize raw text with the field's own tokenizer so it is processed\n",
"        # exactly like the training data (tokenize_de also reverses the tokens).\n",
"        tokens = [token.lower() for token in src_field.tokenize(sentence)]\n",
"    else:\n",
"        tokens = [token.lower() for token in sentence]\n",
"\n",
"    tokens = [src_field.init_token] + tokens + [src_field.eos_token]\n",
"    src_indexes = [src_field.vocab.stoi[token] for token in tokens]\n",
"\n",
"    # A batch holding one sentence: [src len, 1]\n",
"    src_tensor = torch.LongTensor(src_indexes).unsqueeze(1).to(device)\n",
"\n",
"    with torch.no_grad():\n",
"        # The encoder returns only its final (hidden, cell) states\n",
"        hidden, cell = model.encoder(src_tensor)\n",
"\n",
"    eos_index = trg_field.vocab.stoi[trg_field.eos_token]\n",
"    trg_indexes = [trg_field.vocab.stoi[trg_field.init_token]]\n",
"\n",
"    for _ in range(max_len):\n",
"        trg_tensor = torch.LongTensor([trg_indexes[-1]]).to(device)\n",
"\n",
"        with torch.no_grad():\n",
"            # Carry the recurrent state forward: without updating hidden/cell\n",
"            # the decoder sees the same state at every step and keeps\n",
"            # emitting the same token.\n",
"            prediction, hidden, cell = model.decoder(trg_tensor, hidden, cell)\n",
"\n",
"        # prediction: [1, trg vocab size] -> index of the most likely token\n",
"        pred_token = prediction.argmax(1).item()\n",
"        trg_indexes.append(pred_token)\n",
"\n",
"        if pred_token == eos_index:\n",
"            break\n",
"\n",
"    trg_tokens = [trg_field.vocab.itos[i] for i in trg_indexes]\n",
"\n",
"    return trg_tokens[1:]"
],
"execution_count": 67,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "F5N2q9JOy31f",
"outputId": "f0c661c9-4b71-44b2-eaae-b5a06988271a",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 904
}
},
"source": [
"# `src` comes from test_data, so it is already tokenized and reversed like the training data\n",
"translation = translate_sentence(src, SRC, TRG, model, device)\n",
"print(f'predicted trg = {translation}')"
],
"execution_count": 68,
"outputs": [
{
"output_type": "stream",
"text": [
"tensor([2], device='cuda:0')\n",
"tensor([55], device='cuda:0')\n",
"tensor([55], device='cuda:0')\n",
"tensor([55], device='cuda:0')\n",
"tensor([55], device='cuda:0')\n",
"tensor([55], device='cuda:0')\n",
"tensor([55], device='cuda:0')\n",
"tensor([55], device='cuda:0')\n",
"tensor([55], device='cuda:0')\n",
"tensor([55], device='cuda:0')\n",
"tensor([55], device='cuda:0')\n",
"tensor([55], device='cuda:0')\n",
"tensor([55], device='cuda:0')\n",
"tensor([55], device='cuda:0')\n",
"tensor([55], device='cuda:0')\n",
"tensor([55], device='cuda:0')\n",
"tensor([55], device='cuda:0')\n",
"tensor([55], device='cuda:0')\n",
"tensor([55], device='cuda:0')\n",
"tensor([55], device='cuda:0')\n",
"tensor([55], device='cuda:0')\n",
"tensor([55], device='cuda:0')\n",
"tensor([55], device='cuda:0')\n",
"tensor([55], device='cuda:0')\n",
"tensor([55], device='cuda:0')\n",
"tensor([55], device='cuda:0')\n",
"tensor([55], device='cuda:0')\n",
"tensor([55], device='cuda:0')\n",
"tensor([55], device='cuda:0')\n",
"tensor([55], device='cuda:0')\n",
"tensor([55], device='cuda:0')\n",
"tensor([55], device='cuda:0')\n",
"tensor([55], device='cuda:0')\n",
"tensor([55], device='cuda:0')\n",
"tensor([55], device='cuda:0')\n",
"tensor([55], device='cuda:0')\n",
"tensor([55], device='cuda:0')\n",
"tensor([55], device='cuda:0')\n",
"tensor([55], device='cuda:0')\n",
"tensor([55], device='cuda:0')\n",
"tensor([55], device='cuda:0')\n",
"tensor([55], device='cuda:0')\n",
"tensor([55], device='cuda:0')\n",
"tensor([55], device='cuda:0')\n",
"tensor([55], device='cuda:0')\n",
"tensor([55], device='cuda:0')\n",
"tensor([55], device='cuda:0')\n",
"tensor([55], device='cuda:0')\n",
"tensor([55], device='cuda:0')\n",
"tensor([55], device='cuda:0')\n",
"predicted trg = ['child', 'child', 'child', 'child', 'child', 'child', 'child', 'child', 'child', 'child', 'child', 'child', 'child', 'child', 'child', 'child', 'child', 'child', 'child', 'child', 'child', 'child', 'child', 'child', 'child', 'child', 'child', 'child', 'child', 'child', 'child', 'child', 'child', 'child', 'child', 'child', 'child', 'child', 'child', 'child', 'child', 'child', 'child', 'child', 'child', 'child', 'child', 'child', 'child', 'child']\n"
],
"name": "stdout"
}
]
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment