Skip to content

Instantly share code, notes, and snippets.

@snakers4
Created April 10, 2023 05:07
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save snakers4/89eab250c404d71e82a913b0b751488c to your computer and use it in GitHub Desktop.
Save snakers4/89eab250c404d71e82a913b0b751488c to your computer and use it in GitHub Desktop.
test2.ipynb
{
"cells": [
{
"cell_type": "code",
"execution_count": 59,
"metadata": {
"id": "1XEMm5oo36Sm"
},
"outputs": [],
"source": [
"import os\n",
"import time\n",
"import glob\n",
"import numba\n",
"import random\n",
"import itertools\n",
"import torch\n",
"import torch.nn as nn\n",
"import matplotlib.pyplot as plt\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"from tqdm import tqdm\n",
"from sklearn.model_selection import train_test_split\n",
"from torch.nn.utils.rnn import pad_sequence\n",
"from torch.utils.data import Dataset, DataLoader"
]
},
{
"cell_type": "code",
"execution_count": 60,
"metadata": {
"id": "a6-yAg4536Sp"
},
"outputs": [],
"source": [
"RANDOM_SEED = 42"
]
},
{
"cell_type": "code",
"source": [
"data = pd.read_csv('/content/train.csv')"
],
"metadata": {
"id": "U6Vg9a-0KNqY"
},
"execution_count": 61,
"outputs": []
},
{
"cell_type": "code",
"source": [
"data"
],
"metadata": {
"id": "YAlwu5KAKUh4",
"outputId": "7c4d63f3-7130-4014-8baf-2e3dbdff67d1",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 423
}
},
"execution_count": 62,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" id word stress num_syllables lemma\n",
"0 0 румяной 2 3 румяный\n",
"1 1 цифрами 1 3 цифра\n",
"2 2 слугами 1 3 слуга\n",
"3 3 выбирает 3 4 выбирать\n",
"4 4 управдом 3 3 управдом\n",
"... ... ... ... ... ...\n",
"63433 63433 экзамена 2 4 экзамен\n",
"63434 63434 культурой 2 3 культура\n",
"63435 63435 объемной 2 3 объемный\n",
"63436 63436 участком 2 3 участок\n",
"63437 63437 ташкента 2 3 ташкент\n",
"\n",
"[63438 rows x 5 columns]"
],
"text/html": [
"\n",
" <div id=\"df-edc4ed79-b743-47a8-a595-ff6279343978\">\n",
" <div class=\"colab-df-container\">\n",
" <div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>id</th>\n",
" <th>word</th>\n",
" <th>stress</th>\n",
" <th>num_syllables</th>\n",
" <th>lemma</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>0</td>\n",
" <td>румяной</td>\n",
" <td>2</td>\n",
" <td>3</td>\n",
" <td>румяный</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>1</td>\n",
" <td>цифрами</td>\n",
" <td>1</td>\n",
" <td>3</td>\n",
" <td>цифра</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>2</td>\n",
" <td>слугами</td>\n",
" <td>1</td>\n",
" <td>3</td>\n",
" <td>слуга</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>3</td>\n",
" <td>выбирает</td>\n",
" <td>3</td>\n",
" <td>4</td>\n",
" <td>выбирать</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>4</td>\n",
" <td>управдом</td>\n",
" <td>3</td>\n",
" <td>3</td>\n",
" <td>управдом</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>63433</th>\n",
" <td>63433</td>\n",
" <td>экзамена</td>\n",
" <td>2</td>\n",
" <td>4</td>\n",
" <td>экзамен</td>\n",
" </tr>\n",
" <tr>\n",
" <th>63434</th>\n",
" <td>63434</td>\n",
" <td>культурой</td>\n",
" <td>2</td>\n",
" <td>3</td>\n",
" <td>культура</td>\n",
" </tr>\n",
" <tr>\n",
" <th>63435</th>\n",
" <td>63435</td>\n",
" <td>объемной</td>\n",
" <td>2</td>\n",
" <td>3</td>\n",
" <td>объемный</td>\n",
" </tr>\n",
" <tr>\n",
" <th>63436</th>\n",
" <td>63436</td>\n",
" <td>участком</td>\n",
" <td>2</td>\n",
" <td>3</td>\n",
" <td>участок</td>\n",
" </tr>\n",
" <tr>\n",
" <th>63437</th>\n",
" <td>63437</td>\n",
" <td>ташкента</td>\n",
" <td>2</td>\n",
" <td>3</td>\n",
" <td>ташкент</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>63438 rows × 5 columns</p>\n",
"</div>\n",
" <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-edc4ed79-b743-47a8-a595-ff6279343978')\"\n",
" title=\"Convert this dataframe to an interactive table.\"\n",
" style=\"display:none;\">\n",
" \n",
" <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
" width=\"24px\">\n",
" <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n",
" <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n",
" </svg>\n",
" </button>\n",
" \n",
" <style>\n",
" .colab-df-container {\n",
" display:flex;\n",
" flex-wrap:wrap;\n",
" gap: 12px;\n",
" }\n",
"\n",
" .colab-df-convert {\n",
" background-color: #E8F0FE;\n",
" border: none;\n",
" border-radius: 50%;\n",
" cursor: pointer;\n",
" display: none;\n",
" fill: #1967D2;\n",
" height: 32px;\n",
" padding: 0 0 0 0;\n",
" width: 32px;\n",
" }\n",
"\n",
" .colab-df-convert:hover {\n",
" background-color: #E2EBFA;\n",
" box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
" fill: #174EA6;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-convert {\n",
" background-color: #3B4455;\n",
" fill: #D2E3FC;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-convert:hover {\n",
" background-color: #434B5C;\n",
" box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
" filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
" fill: #FFFFFF;\n",
" }\n",
" </style>\n",
"\n",
" <script>\n",
" const buttonEl =\n",
" document.querySelector('#df-edc4ed79-b743-47a8-a595-ff6279343978 button.colab-df-convert');\n",
" buttonEl.style.display =\n",
" google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
"\n",
" async function convertToInteractive(key) {\n",
" const element = document.querySelector('#df-edc4ed79-b743-47a8-a595-ff6279343978');\n",
" const dataTable =\n",
" await google.colab.kernel.invokeFunction('convertToInteractive',\n",
" [key], {});\n",
" if (!dataTable) return;\n",
"\n",
" const docLinkHtml = 'Like what you see? Visit the ' +\n",
" '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
" + ' to learn more about interactive tables.';\n",
" element.innerHTML = '';\n",
" dataTable['output_type'] = 'display_data';\n",
" await google.colab.output.renderOutput(dataTable, element);\n",
" const docLink = document.createElement('div');\n",
" docLink.innerHTML = docLinkHtml;\n",
" element.appendChild(docLink);\n",
" }\n",
" </script>\n",
" </div>\n",
" </div>\n",
" "
]
},
"metadata": {},
"execution_count": 62
}
]
},
{
"cell_type": "markdown",
"source": [
"### Preprocessing"
],
"metadata": {
"id": "bKe_fTixbfIO"
}
},
{
"cell_type": "code",
"source": [
"def stress_pos(word, stress):\n",
" res = np.zeros(len(word))\n",
" for i in range(len(word)):\n",
" if word[i] in ['а', 'о', 'у', 'ы', 'э', 'е', 'ё', 'и', 'ю', 'я']:\n",
" if stress == 1:\n",
" res[i] = 1\n",
" break\n",
" else: \n",
" stress -= 1 \n",
" return res"
],
"metadata": {
"id": "LUaoofm8KaUS"
},
"execution_count": 63,
"outputs": []
},
{
"cell_type": "code",
"source": [
"%%time\n",
"\n",
"data['word_stress_pos'] = data.apply(lambda x: stress_pos(x.word, x.stress), axis=1)"
],
"metadata": {
"id": "MCL0OC5AQ10d",
"outputId": "3ea44f1a-50ef-448c-edad-2a97877665c3",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"execution_count": 64,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"CPU times: user 1.19 s, sys: 10.7 ms, total: 1.2 s\n",
"Wall time: 1.2 s\n"
]
}
]
},
{
"cell_type": "code",
"execution_count": 65,
"metadata": {
"id": "QMJeYC9636St",
"outputId": "6c803ce0-0227-4003-87fa-e93738697237",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"CPU times: user 37.9 ms, sys: 26.2 ms, total: 64.2 ms\n",
"Wall time: 65.6 ms\n"
]
}
],
"source": [
"%%time\n",
"\n",
"data['word_list'] = data.word.map(list)"
]
},
{
"cell_type": "code",
"execution_count": 66,
"metadata": {
"id": "FV6qIcC336Su",
"outputId": "00d1f8c4-41a2-4c90-ab61-d599dd8fcdf5",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 206
}
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" id word stress num_syllables lemma \\\n",
"49671 49671 священником 2 4 священник \n",
"50707 50707 игрушечный 2 4 игрушечный \n",
"29760 29760 полмиллиарда 4 5 полмиллиард \n",
"61955 61955 байгора 1 3 байгора \n",
"28681 28681 цепочки 2 3 цепочка \n",
"\n",
" word_stress_pos \\\n",
"49671 [0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, ... \n",
"50707 [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... \n",
"29760 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, ... \n",
"61955 [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0] \n",
"28681 [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0] \n",
"\n",
" word_list \n",
"49671 [с, в, я, щ, е, н, н, и, к, о, м] \n",
"50707 [и, г, р, у, ш, е, ч, н, ы, й] \n",
"29760 [п, о, л, м, и, л, л, и, а, р, д, а] \n",
"61955 [б, а, й, г, о, р, а] \n",
"28681 [ц, е, п, о, ч, к, и] "
],
"text/html": [
"\n",
" <div id=\"df-cf066de4-3646-4f64-afb8-4077910571cf\">\n",
" <div class=\"colab-df-container\">\n",
" <div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>id</th>\n",
" <th>word</th>\n",
" <th>stress</th>\n",
" <th>num_syllables</th>\n",
" <th>lemma</th>\n",
" <th>word_stress_pos</th>\n",
" <th>word_list</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>49671</th>\n",
" <td>49671</td>\n",
" <td>священником</td>\n",
" <td>2</td>\n",
" <td>4</td>\n",
" <td>священник</td>\n",
" <td>[0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, ...</td>\n",
" <td>[с, в, я, щ, е, н, н, и, к, о, м]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>50707</th>\n",
" <td>50707</td>\n",
" <td>игрушечный</td>\n",
" <td>2</td>\n",
" <td>4</td>\n",
" <td>игрушечный</td>\n",
" <td>[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...</td>\n",
" <td>[и, г, р, у, ш, е, ч, н, ы, й]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>29760</th>\n",
" <td>29760</td>\n",
" <td>полмиллиарда</td>\n",
" <td>4</td>\n",
" <td>5</td>\n",
" <td>полмиллиард</td>\n",
" <td>[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, ...</td>\n",
" <td>[п, о, л, м, и, л, л, и, а, р, д, а]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>61955</th>\n",
" <td>61955</td>\n",
" <td>байгора</td>\n",
" <td>1</td>\n",
" <td>3</td>\n",
" <td>байгора</td>\n",
" <td>[0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0]</td>\n",
" <td>[б, а, й, г, о, р, а]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>28681</th>\n",
" <td>28681</td>\n",
" <td>цепочки</td>\n",
" <td>2</td>\n",
" <td>3</td>\n",
" <td>цепочка</td>\n",
" <td>[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0]</td>\n",
" <td>[ц, е, п, о, ч, к, и]</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>\n",
" <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-cf066de4-3646-4f64-afb8-4077910571cf')\"\n",
" title=\"Convert this dataframe to an interactive table.\"\n",
" style=\"display:none;\">\n",
" \n",
" <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
" width=\"24px\">\n",
" <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n",
" <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n",
" </svg>\n",
" </button>\n",
" \n",
" <style>\n",
" .colab-df-container {\n",
" display:flex;\n",
" flex-wrap:wrap;\n",
" gap: 12px;\n",
" }\n",
"\n",
" .colab-df-convert {\n",
" background-color: #E8F0FE;\n",
" border: none;\n",
" border-radius: 50%;\n",
" cursor: pointer;\n",
" display: none;\n",
" fill: #1967D2;\n",
" height: 32px;\n",
" padding: 0 0 0 0;\n",
" width: 32px;\n",
" }\n",
"\n",
" .colab-df-convert:hover {\n",
" background-color: #E2EBFA;\n",
" box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
" fill: #174EA6;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-convert {\n",
" background-color: #3B4455;\n",
" fill: #D2E3FC;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-convert:hover {\n",
" background-color: #434B5C;\n",
" box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
" filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
" fill: #FFFFFF;\n",
" }\n",
" </style>\n",
"\n",
" <script>\n",
" const buttonEl =\n",
" document.querySelector('#df-cf066de4-3646-4f64-afb8-4077910571cf button.colab-df-convert');\n",
" buttonEl.style.display =\n",
" google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
"\n",
" async function convertToInteractive(key) {\n",
" const element = document.querySelector('#df-cf066de4-3646-4f64-afb8-4077910571cf');\n",
" const dataTable =\n",
" await google.colab.kernel.invokeFunction('convertToInteractive',\n",
" [key], {});\n",
" if (!dataTable) return;\n",
"\n",
" const docLinkHtml = 'Like what you see? Visit the ' +\n",
" '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
" + ' to learn more about interactive tables.';\n",
" element.innerHTML = '';\n",
" dataTable['output_type'] = 'display_data';\n",
" await google.colab.output.renderOutput(dataTable, element);\n",
" const docLink = document.createElement('div');\n",
" docLink.innerHTML = docLinkHtml;\n",
" element.appendChild(docLink);\n",
" }\n",
" </script>\n",
" </div>\n",
" </div>\n",
" "
]
},
"metadata": {},
"execution_count": 66
}
],
"source": [
"data.sample(5)"
]
},
{
"cell_type": "code",
"execution_count": 67,
"metadata": {
"id": "q5LCEprm36Su",
"outputId": "6f278542-1266-4a19-e0a6-412a62953379",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"20"
]
},
"metadata": {},
"execution_count": 67
}
],
"source": [
"max_sequence_len = np.max(data.word.str.len())\n",
"max_sequence_len"
]
},
{
"cell_type": "code",
"execution_count": 68,
"metadata": {
"id": "T1ln0-f436Su"
},
"outputs": [],
"source": [
"def flatten(array):\n",
" for item in array:\n",
" if isinstance(item, list):\n",
" yield from flatten(item)\n",
" else:\n",
" yield item\n",
"\n",
"\n",
"class SequenceTokenizer:\n",
" \n",
" def __init__(self):\n",
" self.word2index = {}\n",
" self.index2word = {}\n",
" self.oov_token ='<UNK>'\n",
" self.oov_token_index = 0\n",
" \n",
" def fit(self, sequence):\n",
" self.index2word = dict(enumerate([self.oov_token] + sorted(set(flatten(sequence))), 1))\n",
" self.word2index = {v:k for k,v in self.index2word.items()}\n",
" self.oov_token_index = self.word2index.get(self.oov_token)\n",
" return self\n",
" \n",
" def transform(self, X):\n",
" res = []\n",
" for line in X:\n",
" res.append([self.word2index.get(item, self.oov_token_index) for item in line])\n",
" return res"
]
},
{
"cell_type": "code",
"execution_count": 69,
"metadata": {
"id": "h4Rl1qJd36Sv"
},
"outputs": [],
"source": [
"tokenizer = SequenceTokenizer()\n",
"tokenizer.fit(data.word_list);"
]
},
{
"cell_type": "code",
"execution_count": 70,
"metadata": {
"id": "xj1bSdhr36Sv",
"outputId": "08ef9c2b-69dc-4c1e-b0a2-311f2709fb0b",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"{'<UNK>': 1,\n",
" 'а': 2,\n",
" 'б': 3,\n",
" 'в': 4,\n",
" 'г': 5,\n",
" 'д': 6,\n",
" 'е': 7,\n",
" 'ж': 8,\n",
" 'з': 9,\n",
" 'и': 10,\n",
" 'й': 11,\n",
" 'к': 12,\n",
" 'л': 13,\n",
" 'м': 14,\n",
" 'н': 15,\n",
" 'о': 16,\n",
" 'п': 17,\n",
" 'р': 18,\n",
" 'с': 19,\n",
" 'т': 20,\n",
" 'у': 21,\n",
" 'ф': 22,\n",
" 'х': 23,\n",
" 'ц': 24,\n",
" 'ч': 25,\n",
" 'ш': 26,\n",
" 'щ': 27,\n",
" 'ъ': 28,\n",
" 'ы': 29,\n",
" 'ь': 30,\n",
" 'э': 31,\n",
" 'ю': 32,\n",
" 'я': 33,\n",
" 'ё': 34}"
]
},
"metadata": {},
"execution_count": 70
}
],
"source": [
"tokenizer.word2index"
]
},
{
"cell_type": "code",
"execution_count": 71,
"metadata": {
"id": "46RLP83Z36Sv"
},
"outputs": [],
"source": [
"X = tokenizer.transform(data.word_list)"
]
},
{
"cell_type": "code",
"execution_count": 72,
"metadata": {
"id": "yInv9QA636Sv"
},
"outputs": [],
"source": [
"def pad_sequence(lst, max_seq=max_sequence_len):\n",
" if isinstance(lst[0], list):\n",
" return np.array([i + [0]*(max_seq-len(i)) for i in lst])\n",
" else:\n",
" lst + [0]*(max_seq-len(lst))"
]
},
{
"cell_type": "code",
"execution_count": 73,
"metadata": {
"id": "YmDQ0jGP36Sw",
"outputId": "b191b59e-76bd-48b1-a2a9-e693a8379adc",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"CPU times: user 129 ms, sys: 5.91 ms, total: 134 ms\n",
"Wall time: 136 ms\n"
]
}
],
"source": [
"%%time\n",
"\n",
"input_seq = pad_sequence(X)"
]
},
{
"cell_type": "code",
"execution_count": 74,
"metadata": {
"id": "vdMK4Yrd36Sw",
"outputId": "18424cdb-f387-4380-a769-978e91d25b30",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"(63438, 20)\n"
]
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"array([[18, 21, 14, ..., 0, 0, 0],\n",
" [24, 10, 22, ..., 0, 0, 0],\n",
" [19, 13, 21, ..., 0, 0, 0],\n",
" ...,\n",
" [16, 3, 28, ..., 0, 0, 0],\n",
" [21, 25, 2, ..., 0, 0, 0],\n",
" [20, 2, 26, ..., 0, 0, 0]])"
]
},
"metadata": {},
"execution_count": 74
}
],
"source": [
"print(input_seq.shape)\n",
"input_seq"
]
},
{
"cell_type": "code",
"execution_count": 75,
"metadata": {
"id": "x5ATpkXb36Sx"
},
"outputs": [],
"source": [
"y = data.word_stress_pos.values"
]
},
{
"cell_type": "code",
"execution_count": 76,
"metadata": {
"id": "izwtrpIy36Sx"
},
"outputs": [],
"source": [
"output_seq = zip(*itertools.zip_longest(*y, fillvalue=0))\n",
"output_seq = list(map(list, output_seq))\n",
"output_seq = np.array(output_seq).astype(int)"
]
},
{
"cell_type": "code",
"execution_count": 77,
"metadata": {
"id": "bqjHJ85O36Sx",
"outputId": "283c2a17-e121-4fd5-955c-bf17e1c04906",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"(63438, 20)\n"
]
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"array([[0, 0, 0, ..., 0, 0, 0],\n",
" [0, 1, 0, ..., 0, 0, 0],\n",
" [0, 0, 1, ..., 0, 0, 0],\n",
" ...,\n",
" [0, 0, 0, ..., 0, 0, 0],\n",
" [0, 0, 1, ..., 0, 0, 0],\n",
" [0, 0, 0, ..., 0, 0, 0]])"
]
},
"metadata": {},
"execution_count": 77
}
],
"source": [
"print(output_seq.shape)\n",
"output_seq"
]
},
{
"cell_type": "code",
"execution_count": 78,
"metadata": {
"id": "xJaSZ-mf36Sx"
},
"outputs": [],
"source": [
"(input_seq_train, input_seq_val, \n",
" output_seq_train, output_seq_val) = train_test_split(input_seq, \n",
" output_seq, \n",
" test_size=0.5, \n",
" random_state=RANDOM_SEED)"
]
},
{
"cell_type": "code",
"execution_count": 79,
"metadata": {
"id": "ezujxONV36Sy"
},
"outputs": [],
"source": [
"input_seq_train = torch.tensor(input_seq_train, dtype=torch.long).cuda()\n",
"input_seq_val = torch.tensor(input_seq_val, dtype=torch.long).cuda()\n",
"output_seq_train = torch.tensor(output_seq_train, dtype=torch.float).cuda()\n",
"output_seq_val = torch.tensor(output_seq_val, dtype=torch.float).cuda()"
]
},
{
"cell_type": "code",
"execution_count": 80,
"metadata": {
"id": "gRsPsPpj36Sy"
},
"outputs": [],
"source": [
"class MyDataset(Dataset):\n",
" def __init__(self, dataset):\n",
" self.dataset = dataset\n",
" \n",
" def __getitem__(self, index):\n",
" data,target = self.dataset[index]\n",
" return data, target, index\n",
" \n",
" def __len__(self):\n",
" return len(self.dataset)"
]
},
{
"cell_type": "markdown",
"source": [
"### Model"
],
"metadata": {
"id": "92jorRHXcGmz"
}
},
{
"cell_type": "code",
"execution_count": 81,
"metadata": {
"id": "UZ5BFgo736Sy"
},
"outputs": [],
"source": [
"class LSTM_model(nn.Module):\n",
"\n",
" def __init__(self, embedding_dim, hidden_dim, vocab_size, target_size):\n",
" super(LSTM_model, self).__init__()\n",
" self.hidden_dim = hidden_dim\n",
"\n",
" self.embeddings = nn.Embedding(vocab_size, embedding_dim)\n",
"\n",
" self.lstm = nn.LSTM(input_size=self.embeddings.embedding_dim,\n",
" hidden_size=hidden_dim,\n",
" num_layers=3,\n",
" batch_first=True,\n",
" bidirectional=True,\n",
" dropout = 0.05)\n",
" self.linear = nn.Linear(self.hidden_dim * 8 , 64)\n",
" self.batch_norm = nn.BatchNorm1d(self.hidden_dim * 8, affine=False)\n",
" self.relu = nn.ReLU()\n",
" self.dropout = nn.Dropout(0.1)\n",
" self.out = nn.Linear(64, target_size)\n",
"\n",
" def forward(self, x):\n",
" h_embeddings = self.embeddings(x)\n",
" \n",
" h_lstm, _ = self.lstm(h_embeddings)\n",
" d_1 = h_lstm[:,0,:]\n",
" d_2 = h_lstm[:,h_lstm.shape[1]//4,:]\n",
" d_3 = h_lstm[:,h_lstm.shape[1]*3//4,:]\n",
" d_4 = h_lstm[:,-1,:]\n",
" x = torch.cat((d_1, d_2, d_3, d_4), 1)\n",
" x = self.batch_norm(x)\n",
" x = self.linear(x)\n",
" x = self.relu(x)\n",
" x = self.dropout(x)\n",
" x = self.out(x)\n",
" y = nn.functional.softmax(x, dim=1)\n",
" return y"
]
},
{
"cell_type": "code",
"execution_count": 82,
"metadata": {
"id": "uA2SJd7736Sy",
"outputId": "bae9801c-448d-4d31-b0d2-48be63acf28b",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"LSTM_model(\n",
" (embeddings): Embedding(35, 64)\n",
" (lstm): LSTM(64, 64, num_layers=3, batch_first=True, dropout=0.05, bidirectional=True)\n",
" (linear): Linear(in_features=512, out_features=64, bias=True)\n",
" (batch_norm): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)\n",
" (relu): ReLU()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" (out): Linear(in_features=64, out_features=20, bias=True)\n",
")"
]
},
"metadata": {},
"execution_count": 82
}
],
"source": [
"model = LSTM_model(embedding_dim=64, \n",
" hidden_dim=64, \n",
" vocab_size=len(tokenizer.word2index) + 1, \n",
" target_size=max_sequence_len)\n",
"model.cuda()"
]
},
{
"cell_type": "code",
"execution_count": 83,
"metadata": {
"id": "y6hRGxF836Sz"
},
"outputs": [],
"source": [
"loss_function = nn.BCEWithLogitsLoss()\n",
"optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3)"
]
},
{
"cell_type": "code",
"execution_count": 84,
"metadata": {
"id": "NZD-Vfjh36Sz"
},
"outputs": [],
"source": [
"BATCH_SIZE = 256 * 2\n",
"\n",
"train = MyDataset(torch.utils.data.TensorDataset(input_seq_train, output_seq_train))\n",
"valid = MyDataset(torch.utils.data.TensorDataset(input_seq_val, output_seq_val))\n",
"\n",
"train_loader = torch.utils.data.DataLoader(train, batch_size=BATCH_SIZE, shuffle=True)\n",
"valid_loader = torch.utils.data.DataLoader(valid, batch_size=BATCH_SIZE, shuffle=False)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jDEMfycQ36Sz"
},
"source": [
"### K-Fold training"
]
},
{
"cell_type": "code",
"execution_count": 85,
"metadata": {
"id": "xIltVQnU36S0"
},
"outputs": [],
"source": [
"from sklearn.model_selection import KFold"
]
},
{
"cell_type": "code",
"execution_count": 86,
"metadata": {
"id": "fN3FS28a36S1"
},
"outputs": [],
"source": [
"n_folds = 5"
]
},
{
"cell_type": "code",
"execution_count": 87,
"metadata": {
"id": "p3-X0BHZ36S1"
},
"outputs": [],
"source": [
"kf = KFold(n_splits=n_folds, shuffle=True, random_state=RANDOM_SEED)"
]
},
{
"cell_type": "code",
"execution_count": 88,
"metadata": {
"id": "m2x5XeCL36S1",
"outputId": "818fb1f8-c1c5-4293-8065-a42249b8bbd4",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"(50750,) (12688,)\n",
"(50750,) (12688,)\n",
"(50750,) (12688,)\n",
"(50751,) (12687,)\n",
"(50751,) (12687,)\n"
]
}
],
"source": [
"for train_index, test_index in kf.split(X=input_seq, y=output_seq):\n",
" print(train_index.shape, test_index.shape)"
]
},
{
"cell_type": "code",
"execution_count": 89,
"metadata": {
"id": "zonYLRzp36S1",
"outputId": "f0aed232-79a7-4cc7-93cc-a40c0402421a",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"array([[18, 21, 14, ..., 0, 0, 0],\n",
" [24, 10, 22, ..., 0, 0, 0],\n",
" [19, 13, 21, ..., 0, 0, 0],\n",
" ...,\n",
" [16, 3, 28, ..., 0, 0, 0],\n",
" [21, 25, 2, ..., 0, 0, 0],\n",
" [20, 2, 26, ..., 0, 0, 0]])"
]
},
"metadata": {},
"execution_count": 89
}
],
"source": [
"input_seq"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "m-qOODva36S2"
},
"source": [
"### Training loop"
]
},
{
"cell_type": "code",
"execution_count": 90,
"metadata": {
"id": "G41070T636S2",
"outputId": "88019e40-3214-4b5e-8db5-5285d059ef0a",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Epoch [1/50] progress = 98% \t loss=0.7073 \t acc=32.47% \n",
"Epoch [1/50] results:\t\t loss=0.7073\t acc=32.47%\t val_loss=0.7030\t val_acc=44.64%\t time=1.42s\n",
"------------------------------------------------------------------------------\n",
"Epoch [2/50] progress = 98% \t loss=0.6954 \t acc=57.92% \n",
"Epoch [2/50] results:\t\t loss=0.6954\t acc=57.92%\t val_loss=0.6914\t val_acc=65.59%\t time=1.47s\n",
"------------------------------------------------------------------------------\n",
"Epoch [3/50] progress = 98% \t loss=0.6900 \t acc=68.33% \n",
"Epoch [3/50] results:\t\t loss=0.6900\t acc=68.33%\t val_loss=0.6890\t val_acc=70.17%\t time=1.35s\n",
"------------------------------------------------------------------------------\n",
"Epoch [4/50] progress = 98% \t loss=0.6882 \t acc=71.96% \n",
"Epoch [4/50] results:\t\t loss=0.6882\t acc=71.96%\t val_loss=0.6883\t val_acc=71.45%\t time=1.68s\n",
"------------------------------------------------------------------------------\n",
"Epoch [5/50] progress = 98% \t loss=0.6875 \t acc=73.19% \n",
"Epoch [5/50] results:\t\t loss=0.6875\t acc=73.19%\t val_loss=0.6876\t val_acc=72.75%\t time=1.56s\n",
"------------------------------------------------------------------------------\n",
"Epoch [6/50] progress = 98% \t loss=0.6866 \t acc=75.19% \n",
"Epoch [6/50] results:\t\t loss=0.6866\t acc=75.19%\t val_loss=0.6870\t val_acc=74.08%\t time=1.32s\n",
"------------------------------------------------------------------------------\n",
"Epoch [7/50] progress = 98% \t loss=0.6860 \t acc=76.11% \n",
"Epoch [7/50] results:\t\t loss=0.6860\t acc=76.11%\t val_loss=0.6864\t val_acc=75.45%\t time=1.31s\n",
"------------------------------------------------------------------------------\n",
"Epoch [8/50] progress = 98% \t loss=0.6855 \t acc=77.21% \n",
"Epoch [8/50] results:\t\t loss=0.6855\t acc=77.21%\t val_loss=0.6861\t val_acc=75.99%\t time=1.47s\n",
"------------------------------------------------------------------------------\n",
"Epoch [9/50] progress = 98% \t loss=0.6852 \t acc=77.76% \n",
"Epoch [9/50] results:\t\t loss=0.6852\t acc=77.76%\t val_loss=0.6857\t val_acc=76.71%\t time=1.31s\n",
"------------------------------------------------------------------------------\n",
"Epoch [10/50] progress = 98% \t loss=0.6848 \t acc=78.58% \n",
"Epoch [10/50] results:\t\t loss=0.6848\t acc=78.58%\t val_loss=0.6857\t val_acc=76.74%\t time=1.35s\n",
"------------------------------------------------------------------------------\n",
"Epoch [11/50] progress = 98% \t loss=0.6842 \t acc=79.83% \n",
"Epoch [11/50] results:\t\t loss=0.6842\t acc=79.83%\t val_loss=0.6849\t val_acc=78.53%\t time=1.32s\n",
"------------------------------------------------------------------------------\n",
"Epoch [12/50] progress = 98% \t loss=0.6838 \t acc=80.68% \n",
"Epoch [12/50] results:\t\t loss=0.6838\t acc=80.68%\t val_loss=0.6848\t val_acc=78.61%\t time=1.48s\n",
"------------------------------------------------------------------------------\n",
"Epoch [13/50] progress = 98% \t loss=0.6838 \t acc=80.61% \n",
"Epoch [13/50] results:\t\t loss=0.6838\t acc=80.61%\t val_loss=0.6847\t val_acc=78.81%\t time=1.66s\n",
"------------------------------------------------------------------------------\n",
"Epoch [14/50] progress = 98% \t loss=0.6833 \t acc=81.56% \n",
"Epoch [14/50] results:\t\t loss=0.6833\t acc=81.56%\t val_loss=0.6844\t val_acc=79.43%\t time=1.51s\n",
"------------------------------------------------------------------------------\n",
"Epoch [15/50] progress = 98% \t loss=0.6832 \t acc=81.91% \n",
"Epoch [15/50] results:\t\t loss=0.6832\t acc=81.91%\t val_loss=0.6844\t val_acc=79.35%\t time=1.34s\n",
"------------------------------------------------------------------------------\n",
"Epoch [16/50] progress = 98% \t loss=0.6829 \t acc=82.49% \n",
"Epoch [16/50] results:\t\t loss=0.6829\t acc=82.49%\t val_loss=0.6841\t val_acc=79.96%\t time=1.35s\n",
"------------------------------------------------------------------------------\n",
"Epoch [17/50] progress = 98% \t loss=0.6827 \t acc=82.89% \n",
"Epoch [17/50] results:\t\t loss=0.6827\t acc=82.89%\t val_loss=0.6843\t val_acc=79.54%\t time=1.33s\n",
"------------------------------------------------------------------------------\n",
"Epoch [18/50] progress = 98% \t loss=0.6825 \t acc=83.25% \n",
"Epoch [18/50] results:\t\t loss=0.6825\t acc=83.25%\t val_loss=0.6840\t val_acc=80.11%\t time=1.34s\n",
"------------------------------------------------------------------------------\n",
"Epoch [19/50] progress = 98% \t loss=0.6823 \t acc=83.72% \n",
"Epoch [19/50] results:\t\t loss=0.6823\t acc=83.72%\t val_loss=0.6839\t val_acc=80.42%\t time=1.37s\n",
"------------------------------------------------------------------------------\n",
"Epoch [20/50] progress = 98% \t loss=0.6822 \t acc=83.85% \n",
"Epoch [20/50] results:\t\t loss=0.6822\t acc=83.85%\t val_loss=0.6838\t val_acc=80.65%\t time=1.50s\n",
"------------------------------------------------------------------------------\n",
"Epoch [21/50] progress = 98% \t loss=0.6820 \t acc=84.19% \n",
"Epoch [21/50] results:\t\t loss=0.6820\t acc=84.19%\t val_loss=0.6840\t val_acc=80.10%\t time=1.61s\n",
"------------------------------------------------------------------------------\n",
"Epoch [22/50] progress = 98% \t loss=0.6819 \t acc=84.51% \n",
"Epoch [22/50] results:\t\t loss=0.6819\t acc=84.51%\t val_loss=0.6836\t val_acc=80.95%\t time=1.51s\n",
"------------------------------------------------------------------------------\n",
"Epoch [23/50] progress = 98% \t loss=0.6817 \t acc=84.92% \n",
"Epoch [23/50] results:\t\t loss=0.6817\t acc=84.92%\t val_loss=0.6835\t val_acc=81.17%\t time=1.71s\n",
"------------------------------------------------------------------------------\n",
"Epoch [24/50] progress = 98% \t loss=0.6815 \t acc=85.27% \n",
"Epoch [24/50] results:\t\t loss=0.6815\t acc=85.27%\t val_loss=0.6836\t val_acc=80.90%\t time=2.25s\n",
"------------------------------------------------------------------------------\n",
"Epoch [25/50] progress = 98% \t loss=0.6815 \t acc=85.42% \n",
"Epoch [25/50] results:\t\t loss=0.6815\t acc=85.42%\t val_loss=0.6834\t val_acc=81.36%\t time=1.41s\n",
"------------------------------------------------------------------------------\n",
"Epoch [26/50] progress = 98% \t loss=0.6814 \t acc=85.62% \n",
"Epoch [26/50] results:\t\t loss=0.6814\t acc=85.62%\t val_loss=0.6834\t val_acc=81.44%\t time=1.49s\n",
"------------------------------------------------------------------------------\n",
"Epoch [27/50] progress = 98% \t loss=0.6812 \t acc=85.88% \n",
"Epoch [27/50] results:\t\t loss=0.6812\t acc=85.88%\t val_loss=0.6835\t val_acc=81.20%\t time=1.33s\n",
"------------------------------------------------------------------------------\n",
"Epoch [28/50] progress = 98% \t loss=0.6811 \t acc=86.17% \n",
"Epoch [28/50] results:\t\t loss=0.6811\t acc=86.17%\t val_loss=0.6832\t val_acc=81.71%\t time=1.33s\n",
"------------------------------------------------------------------------------\n",
"Epoch [29/50] progress = 98% \t loss=0.6809 \t acc=86.44% \n",
"Epoch [29/50] results:\t\t loss=0.6809\t acc=86.44%\t val_loss=0.6832\t val_acc=81.77%\t time=3.03s\n",
"------------------------------------------------------------------------------\n",
"Epoch [30/50] progress = 98% \t loss=0.6807 \t acc=86.77% \n",
"Epoch [30/50] results:\t\t loss=0.6807\t acc=86.77%\t val_loss=0.6832\t val_acc=81.76%\t time=1.34s\n",
"------------------------------------------------------------------------------\n",
"Epoch [31/50] progress = 98% \t loss=0.6806 \t acc=87.00% \n",
"Epoch [31/50] results:\t\t loss=0.6806\t acc=87.00%\t val_loss=0.6831\t val_acc=82.11%\t time=1.51s\n",
"------------------------------------------------------------------------------\n",
"Epoch [32/50] progress = 98% \t loss=0.6805 \t acc=87.17% \n",
"Epoch [32/50] results:\t\t loss=0.6805\t acc=87.17%\t val_loss=0.6829\t val_acc=82.33%\t time=1.33s\n",
"------------------------------------------------------------------------------\n",
"Epoch [33/50] progress = 98% \t loss=0.6805 \t acc=87.32% \n",
"Epoch [33/50] results:\t\t loss=0.6805\t acc=87.32%\t val_loss=0.6829\t val_acc=82.48%\t time=1.32s\n",
"------------------------------------------------------------------------------\n",
"Epoch [34/50] progress = 98% \t loss=0.6803 \t acc=87.62% \n",
"Epoch [34/50] results:\t\t loss=0.6803\t acc=87.62%\t val_loss=0.6829\t val_acc=82.31%\t time=1.31s\n",
"------------------------------------------------------------------------------\n",
"Epoch [35/50] progress = 98% \t loss=0.6802 \t acc=87.90% \n",
"Epoch [35/50] results:\t\t loss=0.6802\t acc=87.90%\t val_loss=0.6828\t val_acc=82.62%\t time=1.49s\n",
"------------------------------------------------------------------------------\n",
"Epoch [36/50] progress = 98% \t loss=0.6800 \t acc=88.25% \n",
"Epoch [36/50] results:\t\t loss=0.6800\t acc=88.25%\t val_loss=0.6827\t val_acc=82.84%\t time=1.33s\n",
"------------------------------------------------------------------------------\n",
"Epoch [37/50] progress = 98% \t loss=0.6799 \t acc=88.48% \n",
"Epoch [37/50] results:\t\t loss=0.6799\t acc=88.48%\t val_loss=0.6828\t val_acc=82.47%\t time=1.63s\n",
"------------------------------------------------------------------------------\n",
"Epoch [38/50] progress = 98% \t loss=0.6798 \t acc=88.66% \n",
"Epoch [38/50] results:\t\t loss=0.6798\t acc=88.66%\t val_loss=0.6828\t val_acc=82.68%\t time=2.02s\n",
"------------------------------------------------------------------------------\n",
"Epoch [39/50] progress = 98% \t loss=0.6797 \t acc=88.83% \n",
"Epoch [39/50] results:\t\t loss=0.6797\t acc=88.83%\t val_loss=0.6826\t val_acc=83.10%\t time=1.33s\n",
"------------------------------------------------------------------------------\n",
"Epoch [40/50] progress = 98% \t loss=0.6797 \t acc=89.00% \n",
"Epoch [40/50] results:\t\t loss=0.6797\t acc=89.00%\t val_loss=0.6826\t val_acc=82.97%\t time=1.48s\n",
"------------------------------------------------------------------------------\n",
"Epoch [41/50] progress = 98% \t loss=0.6797 \t acc=88.96% \n",
"Epoch [41/50] results:\t\t loss=0.6797\t acc=88.96%\t val_loss=0.6825\t val_acc=83.18%\t time=1.31s\n",
"------------------------------------------------------------------------------\n",
"Epoch [42/50] progress = 98% \t loss=0.6795 \t acc=89.27% \n",
"Epoch [42/50] results:\t\t loss=0.6795\t acc=89.27%\t val_loss=0.6825\t val_acc=83.26%\t time=1.31s\n",
"------------------------------------------------------------------------------\n",
"Epoch [43/50] progress = 98% \t loss=0.6795 \t acc=89.32% \n",
"Epoch [43/50] results:\t\t loss=0.6795\t acc=89.32%\t val_loss=0.6825\t val_acc=83.19%\t time=1.32s\n",
"------------------------------------------------------------------------------\n",
"Epoch [44/50] progress = 98% \t loss=0.6794 \t acc=89.50% \n",
"Epoch [44/50] results:\t\t loss=0.6794\t acc=89.50%\t val_loss=0.6824\t val_acc=83.39%\t time=1.49s\n",
"------------------------------------------------------------------------------\n",
"Epoch [45/50] progress = 98% \t loss=0.6794 \t acc=89.50% \n",
"Epoch [45/50] results:\t\t loss=0.6794\t acc=89.50%\t val_loss=0.6824\t val_acc=83.42%\t time=1.49s\n",
"------------------------------------------------------------------------------\n",
"Epoch [46/50] progress = 98% \t loss=0.6793 \t acc=89.73% \n",
"Epoch [46/50] results:\t\t loss=0.6793\t acc=89.73%\t val_loss=0.6824\t val_acc=83.43%\t time=1.65s\n",
"------------------------------------------------------------------------------\n",
"Epoch [47/50] progress = 98% \t loss=0.6792 \t acc=89.94% \n",
"Epoch [47/50] results:\t\t loss=0.6792\t acc=89.94%\t val_loss=0.6824\t val_acc=83.47%\t time=1.48s\n",
"------------------------------------------------------------------------------\n",
"Epoch [48/50] progress = 98% \t loss=0.6792 \t acc=89.92% \n",
"Epoch [48/50] results:\t\t loss=0.6792\t acc=89.92%\t val_loss=0.6823\t val_acc=83.59%\t time=1.32s\n",
"------------------------------------------------------------------------------\n",
"Epoch [49/50] progress = 98% \t loss=0.6791 \t acc=90.01% \n",
"Epoch [49/50] results:\t\t loss=0.6791\t acc=90.01%\t val_loss=0.6825\t val_acc=83.20%\t time=1.31s\n",
"------------------------------------------------------------------------------\n",
"Epoch [50/50] progress = 98% \t loss=0.6791 \t acc=90.14% \n",
"Epoch [50/50] results:\t\t loss=0.6791\t acc=90.14%\t val_loss=0.6823\t val_acc=83.66%\t time=1.52s\n",
"------------------------------------------------------------------------------\n"
]
}
],
"source": [
"n_epochs = 50\n",
"history = {'train': {}, 'val': {}}\n",
"teacher_forcing_ratio = 0.5\n",
"\n",
"for epoch in range(1, n_epochs + 1):\n",
" start_time = time.time()\n",
" \n",
" model.train()\n",
"\n",
" avg_loss, total_loss, avg_acc, total_acc = 0., 0., 0., 0.\n",
" for i, (x_batch, y_batch, index) in enumerate(train_loader):\n",
" y_pred = model(x_batch)\n",
" loss = loss_function(y_pred, y_batch)\n",
" \n",
" optimizer.zero_grad()\n",
" loss.backward()\n",
" optimizer.step()\n",
" \n",
" equal = torch.eq(torch.argmax(y_pred, axis=1), torch.argmax(y_batch, axis=1))\n",
" batch_acc = int(equal.sum(-1)) / y_batch.shape[0]\n",
" batch_loss = loss.item()\n",
" \n",
" total_acc += batch_acc\n",
" total_loss += batch_loss\n",
" print(f\"\\rEpoch [{epoch}/{n_epochs}] \"\n",
" f\" progress = {round(i/len(train_loader)*100)}% \"\n",
" f\"\\t loss={total_loss / (i + 1):.4f} \"\n",
" f\"\\t acc={total_acc / (i + 1) * 100:.2f}% \", end='')\n",
" avg_loss = total_loss / len(train_loader)\n",
" avg_acc = total_acc / len(train_loader)\n",
" history['train']['loss'] = history.get('train', {}).get('loss', []) + [avg_loss]\n",
" history['train']['accuracy'] = history.get('train', {}).get('accuracy', []) + [avg_acc]\n",
" \n",
" model.eval()\n",
"\n",
" \n",
" avg_val_loss, total_val_loss, avg_val_acc, total_val_acc = 0., 0., 0., 0.\n",
" for i, (x_batch, y_batch, index) in enumerate(valid_loader):\n",
" y_pred = model(x_batch).detach()\n",
" val_loss = loss_function(y_pred, y_batch)\n",
" \n",
" equal = torch.eq(torch.argmax(y_pred, axis=1), torch.argmax(y_batch, axis=1))\n",
" batch_val_acc = int(equal.sum(-1)) / y_batch.shape[0]\n",
" batch_val_loss = val_loss.item()\n",
" \n",
" total_val_acc += batch_val_acc\n",
" total_val_loss += batch_val_loss\n",
" avg_val_loss = total_val_loss / len(valid_loader)\n",
" avg_val_acc = total_val_acc / len(valid_loader)\n",
" history['val']['loss'] = history.get('val', {}).get('loss', []) + [avg_val_loss]\n",
" history['val']['accuracy'] = history.get('val', {}).get('accuracy', []) + [avg_val_acc]\n",
" \n",
" elapsed_time = time.time() - start_time \n",
" print(f\"\\nEpoch [{epoch}/{n_epochs}] results:\"\n",
" f\"\\t\\t loss={avg_loss:.4f}\"\n",
" f\"\\t acc={avg_acc * 100:.2f}%\"\n",
" f\"\\t val_loss={avg_val_loss:.4f}\"\n",
" f\"\\t val_acc={avg_val_acc * 100:.2f}%\"\n",
" f\"\\t time={elapsed_time:.2f}s\")\n",
" print(\"-\"*78)"
]
},
{
"cell_type": "code",
"source": [
"torch.save(model.state_dict(),\"accentor.pt\")"
],
"metadata": {
"id": "t_KzF99MbYZ2"
},
"execution_count": 93,
"outputs": []
},
{
"cell_type": "code",
"execution_count": 94,
"metadata": {
"id": "gLj87-3T36S2",
"outputId": "7d666ca6-0865-4342-b10b-9320d92331b4",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 718
}
},
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 1200x800 with 1 Axes>"
],
"image/png": "\n"
},
"metadata": {}
}
],
"source": [
"plt.figure(figsize=(12, 8))\n",
"\n",
"plt.plot(history['train']['loss'])\n",
"plt.plot(history['val']['loss'])\n",
"plt.title('model loss')\n",
"plt.ylabel('val')\n",
"plt.xlabel('epoch')\n",
"plt.xticks(np.arange(len(history['train']['loss'])), np.arange(1, len(history['train']['loss']) + 1))\n",
"plt.legend(['train', 'val'], loc='upper right')\n",
"\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 95,
"metadata": {
"id": "149EM4q-36S2",
"outputId": "b11e27f9-7cb1-4690-8468-f41a9dbfe8f4",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 718
}
},
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 1200x800 with 1 Axes>"
],
"image/png": "\n"
},
"metadata": {}
}
],
"source": [
"plt.figure(figsize=(12, 8))\n",
"\n",
"plt.plot(history['train']['accuracy'])\n",
"plt.plot(history['val']['accuracy'])\n",
"plt.title('model accuracy')\n",
"plt.ylabel('accuracy')\n",
"plt.xlabel('epoch')\n",
"plt.xticks(np.arange(len(history['train']['accuracy'])), np.arange(1, len(history['train']['accuracy']) + 1))\n",
"plt.legend(['train', 'val'], loc='lower right')\n",
"\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5DaCsCk836S3"
},
"source": [
"### Predict"
]
},
{
"cell_type": "code",
"source": [
"test = pd.read_csv('/content/test.csv')"
],
"metadata": {
"id": "SZuNwt_0VD4H"
},
"execution_count": 96,
"outputs": []
},
{
"cell_type": "code",
"source": [
"test"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 423
},
"id": "6ZkZMm41WD6t",
"outputId": "6b99b007-964f-4bf6-bd3e-1ea4be497869"
},
"execution_count": 97,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" id word num_syllables lemma\n",
"0 0 эпилепсия 5 эпилепсия\n",
"1 1 относящейся 5 относиться\n",
"2 2 размышлениями 6 размышление\n",
"3 3 модемы 3 модем\n",
"4 4 солнц 1 солнце\n",
"... ... ... ... ...\n",
"29955 29955 донбасса 3 донбасс\n",
"29956 29956 обложка 3 обложка\n",
"29957 29957 правителя 4 правитель\n",
"29958 29958 шерстяной 3 шерстяной\n",
"29959 29959 оптимизации 6 оптимизация\n",
"\n",
"[29960 rows x 4 columns]"
],
"text/html": [
"\n",
" <div id=\"df-298679aa-e94c-4753-88bc-8080a262015b\">\n",
" <div class=\"colab-df-container\">\n",
" <div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>id</th>\n",
" <th>word</th>\n",
" <th>num_syllables</th>\n",
" <th>lemma</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>0</td>\n",
" <td>эпилепсия</td>\n",
" <td>5</td>\n",
" <td>эпилепсия</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>1</td>\n",
" <td>относящейся</td>\n",
" <td>5</td>\n",
" <td>относиться</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>2</td>\n",
" <td>размышлениями</td>\n",
" <td>6</td>\n",
" <td>размышление</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>3</td>\n",
" <td>модемы</td>\n",
" <td>3</td>\n",
" <td>модем</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>4</td>\n",
" <td>солнц</td>\n",
" <td>1</td>\n",
" <td>солнце</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>29955</th>\n",
" <td>29955</td>\n",
" <td>донбасса</td>\n",
" <td>3</td>\n",
" <td>донбасс</td>\n",
" </tr>\n",
" <tr>\n",
" <th>29956</th>\n",
" <td>29956</td>\n",
" <td>обложка</td>\n",
" <td>3</td>\n",
" <td>обложка</td>\n",
" </tr>\n",
" <tr>\n",
" <th>29957</th>\n",
" <td>29957</td>\n",
" <td>правителя</td>\n",
" <td>4</td>\n",
" <td>правитель</td>\n",
" </tr>\n",
" <tr>\n",
" <th>29958</th>\n",
" <td>29958</td>\n",
" <td>шерстяной</td>\n",
" <td>3</td>\n",
" <td>шерстяной</td>\n",
" </tr>\n",
" <tr>\n",
" <th>29959</th>\n",
" <td>29959</td>\n",
" <td>оптимизации</td>\n",
" <td>6</td>\n",
" <td>оптимизация</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>29960 rows × 4 columns</p>\n",
"</div>\n",
" <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-298679aa-e94c-4753-88bc-8080a262015b')\"\n",
" title=\"Convert this dataframe to an interactive table.\"\n",
" style=\"display:none;\">\n",
" \n",
" <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
" width=\"24px\">\n",
" <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n",
" <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n",
" </svg>\n",
" </button>\n",
" \n",
" <style>\n",
" .colab-df-container {\n",
" display:flex;\n",
" flex-wrap:wrap;\n",
" gap: 12px;\n",
" }\n",
"\n",
" .colab-df-convert {\n",
" background-color: #E8F0FE;\n",
" border: none;\n",
" border-radius: 50%;\n",
" cursor: pointer;\n",
" display: none;\n",
" fill: #1967D2;\n",
" height: 32px;\n",
" padding: 0 0 0 0;\n",
" width: 32px;\n",
" }\n",
"\n",
" .colab-df-convert:hover {\n",
" background-color: #E2EBFA;\n",
" box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
" fill: #174EA6;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-convert {\n",
" background-color: #3B4455;\n",
" fill: #D2E3FC;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-convert:hover {\n",
" background-color: #434B5C;\n",
" box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
" filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
" fill: #FFFFFF;\n",
" }\n",
" </style>\n",
"\n",
" <script>\n",
" const buttonEl =\n",
" document.querySelector('#df-298679aa-e94c-4753-88bc-8080a262015b button.colab-df-convert');\n",
" buttonEl.style.display =\n",
" google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
"\n",
" async function convertToInteractive(key) {\n",
" const element = document.querySelector('#df-298679aa-e94c-4753-88bc-8080a262015b');\n",
" const dataTable =\n",
" await google.colab.kernel.invokeFunction('convertToInteractive',\n",
" [key], {});\n",
" if (!dataTable) return;\n",
"\n",
" const docLinkHtml = 'Like what you see? Visit the ' +\n",
" '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
" + ' to learn more about interactive tables.';\n",
" element.innerHTML = '';\n",
" dataTable['output_type'] = 'display_data';\n",
" await google.colab.output.renderOutput(dataTable, element);\n",
" const docLink = document.createElement('div');\n",
" docLink.innerHTML = docLinkHtml;\n",
" element.appendChild(docLink);\n",
" }\n",
" </script>\n",
" </div>\n",
" </div>\n",
" "
]
},
"metadata": {},
"execution_count": 97
}
]
},
{
"cell_type": "code",
"source": [
"def stress_pred(words):\n",
" tokens = pad_sequence(tokenizer.transform(words))\n",
" sequences = torch.tensor(tokens, dtype=torch.long).cuda()\n",
" preds = model(sequences)\n",
" indeces = torch.argmax(preds, axis=1)\n",
" indeces = indeces.to('cpu').numpy()\n",
" res = []\n",
" for i in range(len(indeces)):\n",
" pos = indeces[i]\n",
" coun = 1\n",
" for j in range(pos - 1,-1,-1,):\n",
" if words[i][j] in {'а', 'о', 'у', 'ы', 'э', 'е', 'ё', 'и', 'ю', 'я'}:\n",
" coun += 1\n",
" res.append(coun)\n",
" return res"
],
"metadata": {
"id": "xK1ZTbtRd644"
},
"execution_count": 98,
"outputs": []
},
{
"cell_type": "code",
"source": [
"words = test['word'].tolist()"
],
"metadata": {
"id": "ajqABWgCYJ0n"
},
"execution_count": 99,
"outputs": []
},
{
"cell_type": "code",
"source": [
"pred_stress = stress_pred(words)"
],
"metadata": {
"id": "BnQIjobDlnoP"
},
"execution_count": 100,
"outputs": []
},
{
"cell_type": "code",
"source": [
"test['pred_stress'] = pred_stress"
],
"metadata": {
"id": "1UClWDbuZYWK"
},
"execution_count": 101,
"outputs": []
},
{
"cell_type": "code",
"source": [
"test"
],
"metadata": {
"id": "eaHU-R2LmII3",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 423
},
"outputId": "5d143eb5-f0c3-41c7-d982-ece4bf5596f8"
},
"execution_count": 102,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" id word num_syllables lemma pred_stress\n",
"0 0 эпилепсия 5 эпилепсия 3\n",
"1 1 относящейся 5 относиться 3\n",
"2 2 размышлениями 6 размышление 3\n",
"3 3 модемы 3 модем 2\n",
"4 4 солнц 1 солнце 1\n",
"... ... ... ... ... ...\n",
"29955 29955 донбасса 3 донбасс 2\n",
"29956 29956 обложка 3 обложка 2\n",
"29957 29957 правителя 4 правитель 2\n",
"29958 29958 шерстяной 3 шерстяной 2\n",
"29959 29959 оптимизации 6 оптимизация 4\n",
"\n",
"[29960 rows x 5 columns]"
],
"text/html": [
"\n",
" <div id=\"df-58d65a88-788e-4cef-aa6e-c99c96833dc8\">\n",
" <div class=\"colab-df-container\">\n",
" <div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>id</th>\n",
" <th>word</th>\n",
" <th>num_syllables</th>\n",
" <th>lemma</th>\n",
" <th>pred_stress</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>0</td>\n",
" <td>эпилепсия</td>\n",
" <td>5</td>\n",
" <td>эпилепсия</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>1</td>\n",
" <td>относящейся</td>\n",
" <td>5</td>\n",
" <td>относиться</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>2</td>\n",
" <td>размышлениями</td>\n",
" <td>6</td>\n",
" <td>размышление</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>3</td>\n",
" <td>модемы</td>\n",
" <td>3</td>\n",
" <td>модем</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>4</td>\n",
" <td>солнц</td>\n",
" <td>1</td>\n",
" <td>солнце</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>29955</th>\n",
" <td>29955</td>\n",
" <td>донбасса</td>\n",
" <td>3</td>\n",
" <td>донбасс</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>29956</th>\n",
" <td>29956</td>\n",
" <td>обложка</td>\n",
" <td>3</td>\n",
" <td>обложка</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>29957</th>\n",
" <td>29957</td>\n",
" <td>правителя</td>\n",
" <td>4</td>\n",
" <td>правитель</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>29958</th>\n",
" <td>29958</td>\n",
" <td>шерстяной</td>\n",
" <td>3</td>\n",
" <td>шерстяной</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>29959</th>\n",
" <td>29959</td>\n",
" <td>оптимизации</td>\n",
" <td>6</td>\n",
" <td>оптимизация</td>\n",
" <td>4</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>29960 rows × 5 columns</p>\n",
"</div>\n",
" <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-58d65a88-788e-4cef-aa6e-c99c96833dc8')\"\n",
" title=\"Convert this dataframe to an interactive table.\"\n",
" style=\"display:none;\">\n",
" \n",
" <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
" width=\"24px\">\n",
" <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n",
" <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n",
" </svg>\n",
" </button>\n",
" \n",
" <style>\n",
" .colab-df-container {\n",
" display:flex;\n",
" flex-wrap:wrap;\n",
" gap: 12px;\n",
" }\n",
"\n",
" .colab-df-convert {\n",
" background-color: #E8F0FE;\n",
" border: none;\n",
" border-radius: 50%;\n",
" cursor: pointer;\n",
" display: none;\n",
" fill: #1967D2;\n",
" height: 32px;\n",
" padding: 0 0 0 0;\n",
" width: 32px;\n",
" }\n",
"\n",
" .colab-df-convert:hover {\n",
" background-color: #E2EBFA;\n",
" box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
" fill: #174EA6;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-convert {\n",
" background-color: #3B4455;\n",
" fill: #D2E3FC;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-convert:hover {\n",
" background-color: #434B5C;\n",
" box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
" filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
" fill: #FFFFFF;\n",
" }\n",
" </style>\n",
"\n",
" <script>\n",
" const buttonEl =\n",
" document.querySelector('#df-58d65a88-788e-4cef-aa6e-c99c96833dc8 button.colab-df-convert');\n",
" buttonEl.style.display =\n",
" google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
"\n",
" async function convertToInteractive(key) {\n",
" const element = document.querySelector('#df-58d65a88-788e-4cef-aa6e-c99c96833dc8');\n",
" const dataTable =\n",
" await google.colab.kernel.invokeFunction('convertToInteractive',\n",
" [key], {});\n",
" if (!dataTable) return;\n",
"\n",
" const docLinkHtml = 'Like what you see? Visit the ' +\n",
" '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
" + ' to learn more about interactive tables.';\n",
" element.innerHTML = '';\n",
" dataTable['output_type'] = 'display_data';\n",
" await google.colab.output.renderOutput(dataTable, element);\n",
" const docLink = document.createElement('div');\n",
" docLink.innerHTML = docLinkHtml;\n",
" element.appendChild(docLink);\n",
" }\n",
" </script>\n",
" </div>\n",
" </div>\n",
" "
]
},
"metadata": {},
"execution_count": 102
}
]
},
{
"cell_type": "code",
"source": [
"test.to_csv('pred.csv')"
],
"metadata": {
"id": "p9M0Oei0Z16J"
},
"execution_count": 103,
"outputs": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ujctqO-d36S4"
},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "--8Nn7QD36S4"
},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Zy4DN7mL36S4"
},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "tqzDD25l36S4"
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.8.10 64-bit",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
},
"vscode": {
"interpreter": {
"hash": "767d51c1340bd893661ea55ea3124f6de3c7a262a8b4abca0554b478b1e2ff90"
}
},
"colab": {
"provenance": []
},
"accelerator": "GPU",
"gpuClass": "standard"
},
"nbformat": 4,
"nbformat_minor": 0
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment