Skip to content

Instantly share code, notes, and snippets.

@seanbenhur
Created January 25, 2021 05:59
Show Gist options
  • Save seanbenhur/72024be5b416c70121a8741323b27dbd to your computer and use it in GitHub Desktop.
Save seanbenhur/72024be5b416c70121a8741323b27dbd to your computer and use it in GitHub Desktop.
MLT .ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "MLT .ipynb",
"provenance": [],
"mount_file_id": "19g3duD087i-qbuq3XutAsxUkD9rPx-Cm",
"authorship_tag": "ABX9TyOUYUTOfsDx2EpDyMYnS8b1",
"include_colab_link": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/seanbenhur/72024be5b416c70121a8741323b27dbd/mlt.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "XmWpFnT1jy_F",
"outputId": "a1cc29cc-c55f-48f6-b2d5-82fff06102dc"
},
"source": [
"!pip install torchtext==0.6.0"
],
"execution_count": 21,
"outputs": [
{
"output_type": "stream",
"text": [
"Requirement already satisfied: torchtext==0.6.0 in /usr/local/lib/python3.6/dist-packages (0.6.0)\n",
"Requirement already satisfied: requests in /usr/local/lib/python3.6/dist-packages (from torchtext==0.6.0) (2.23.0)\n",
"Requirement already satisfied: torch in /usr/local/lib/python3.6/dist-packages (from torchtext==0.6.0) (1.7.0+cu101)\n",
"Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from torchtext==0.6.0) (1.15.0)\n",
"Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from torchtext==0.6.0) (1.19.5)\n",
"Requirement already satisfied: tqdm in /usr/local/lib/python3.6/dist-packages (from torchtext==0.6.0) (4.41.1)\n",
"Requirement already satisfied: sentencepiece in /usr/local/lib/python3.6/dist-packages (from torchtext==0.6.0) (0.1.95)\n",
"Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests->torchtext==0.6.0) (2.10)\n",
"Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests->torchtext==0.6.0) (3.0.4)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests->torchtext==0.6.0) (2020.12.5)\n",
"Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests->torchtext==0.6.0) (1.24.3)\n",
"Requirement already satisfied: future in /usr/local/lib/python3.6/dist-packages (from torch->torchtext==0.6.0) (0.16.0)\n",
"Requirement already satisfied: typing-extensions in /usr/local/lib/python3.6/dist-packages (from torch->torchtext==0.6.0) (3.7.4.3)\n",
"Requirement already satisfied: dataclasses in /usr/local/lib/python3.6/dist-packages (from torch->torchtext==0.6.0) (0.8)\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "2_dl5Kn2o0Gj",
"outputId": "716550c1-6484-48d3-ced0-55872956b2b4"
},
"source": [
"!python -m spacy download de_core_news_sm"
],
"execution_count": 22,
"outputs": [
{
"output_type": "stream",
"text": [
"Requirement already satisfied: de_core_news_sm==2.2.5 from https://github.com/explosion/spacy-models/releases/download/de_core_news_sm-2.2.5/de_core_news_sm-2.2.5.tar.gz#egg=de_core_news_sm==2.2.5 in /usr/local/lib/python3.6/dist-packages (2.2.5)\n",
"Requirement already satisfied: spacy>=2.2.2 in /usr/local/lib/python3.6/dist-packages (from de_core_news_sm==2.2.5) (2.2.4)\n",
"Requirement already satisfied: srsly<1.1.0,>=1.0.2 in /usr/local/lib/python3.6/dist-packages (from spacy>=2.2.2->de_core_news_sm==2.2.5) (1.0.5)\n",
"Requirement already satisfied: catalogue<1.1.0,>=0.0.7 in /usr/local/lib/python3.6/dist-packages (from spacy>=2.2.2->de_core_news_sm==2.2.5) (1.0.0)\n",
"Requirement already satisfied: numpy>=1.15.0 in /usr/local/lib/python3.6/dist-packages (from spacy>=2.2.2->de_core_news_sm==2.2.5) (1.19.5)\n",
"Requirement already satisfied: murmurhash<1.1.0,>=0.28.0 in /usr/local/lib/python3.6/dist-packages (from spacy>=2.2.2->de_core_news_sm==2.2.5) (1.0.5)\n",
"Requirement already satisfied: setuptools in /usr/local/lib/python3.6/dist-packages (from spacy>=2.2.2->de_core_news_sm==2.2.5) (51.3.3)\n",
"Requirement already satisfied: tqdm<5.0.0,>=4.38.0 in /usr/local/lib/python3.6/dist-packages (from spacy>=2.2.2->de_core_news_sm==2.2.5) (4.41.1)\n",
"Requirement already satisfied: wasabi<1.1.0,>=0.4.0 in /usr/local/lib/python3.6/dist-packages (from spacy>=2.2.2->de_core_news_sm==2.2.5) (0.8.0)\n",
"Requirement already satisfied: plac<1.2.0,>=0.9.6 in /usr/local/lib/python3.6/dist-packages (from spacy>=2.2.2->de_core_news_sm==2.2.5) (1.1.3)\n",
"Requirement already satisfied: thinc==7.4.0 in /usr/local/lib/python3.6/dist-packages (from spacy>=2.2.2->de_core_news_sm==2.2.5) (7.4.0)\n",
"Requirement already satisfied: preshed<3.1.0,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from spacy>=2.2.2->de_core_news_sm==2.2.5) (3.0.5)\n",
"Requirement already satisfied: cymem<2.1.0,>=2.0.2 in /usr/local/lib/python3.6/dist-packages (from spacy>=2.2.2->de_core_news_sm==2.2.5) (2.0.5)\n",
"Requirement already satisfied: requests<3.0.0,>=2.13.0 in /usr/local/lib/python3.6/dist-packages (from spacy>=2.2.2->de_core_news_sm==2.2.5) (2.23.0)\n",
"Requirement already satisfied: blis<0.5.0,>=0.4.0 in /usr/local/lib/python3.6/dist-packages (from spacy>=2.2.2->de_core_news_sm==2.2.5) (0.4.1)\n",
"Requirement already satisfied: importlib-metadata>=0.20; python_version < \"3.8\" in /usr/local/lib/python3.6/dist-packages (from catalogue<1.1.0,>=0.0.7->spacy>=2.2.2->de_core_news_sm==2.2.5) (3.3.0)\n",
"Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests<3.0.0,>=2.13.0->spacy>=2.2.2->de_core_news_sm==2.2.5) (2.10)\n",
"Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests<3.0.0,>=2.13.0->spacy>=2.2.2->de_core_news_sm==2.2.5) (1.24.3)\n",
"Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests<3.0.0,>=2.13.0->spacy>=2.2.2->de_core_news_sm==2.2.5) (3.0.4)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests<3.0.0,>=2.13.0->spacy>=2.2.2->de_core_news_sm==2.2.5) (2020.12.5)\n",
"Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.6/dist-packages (from importlib-metadata>=0.20; python_version < \"3.8\"->catalogue<1.1.0,>=0.0.7->spacy>=2.2.2->de_core_news_sm==2.2.5) (3.4.0)\n",
"Requirement already satisfied: typing-extensions>=3.6.4; python_version < \"3.8\" in /usr/local/lib/python3.6/dist-packages (from importlib-metadata>=0.20; python_version < \"3.8\"->catalogue<1.1.0,>=0.0.7->spacy>=2.2.2->de_core_news_sm==2.2.5) (3.7.4.3)\n",
"\u001b[38;5;2m✔ Download and installation successful\u001b[0m\n",
"You can now load the model via spacy.load('de_core_news_sm')\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "bnw2p2LSilxM"
},
"source": [
"import torchtext\r\n",
"from torchtext import data\r\n",
"from torchtext.data import Field,BucketIterator\r\n",
"from torchtext.datasets import Multi30k\r\n",
"from torchtext.data.metrics import bleu_score\r\n",
"import random\r\n",
"import torch\r\n",
"import torch.nn as nn\r\n",
"import torch.nn.functional as F\r\n",
"import torch.optim as optim \r\n",
"import numpy as np\r\n",
"import spacy"
],
"execution_count": 23,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "DcwuzwVgFMo3"
},
"source": [
"import os\r\n",
"os.environ['CUDA_LAUNCH_BLOCKING'] = \"1\""
],
"execution_count": 24,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "kTMFT1sYkIhg"
},
"source": [
"def set_seed(seed):\r\n",
" random.seed(seed)\r\n",
" np.random.seed(seed)\r\n",
" torch.manual_seed(seed)\r\n",
" torch.cuda.manual_seed(seed)\r\n",
" torch.backends.cudnn.deterministic = True"
],
"execution_count": 25,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "Emi5IU8Nld1Y"
},
"source": [
"seed = 1234\r\n",
"set_seed(seed)"
],
"execution_count": 26,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "BhEo62Y1mknd"
},
"source": [
"spacy_en = spacy.load('en')\r\n",
"spacy_de = spacy.load('de_core_news_sm')"
],
"execution_count": 27,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "4KE5iQ8llyEb"
},
"source": [
"def tokenize_german(text):\r\n",
" \"\"\"\r\n",
" Tokenizes german text using spacy\r\n",
" \"\"\"\r\n",
" return [tok.text for tok in spacy_de.tokenizer(text)]\r\n",
"\r\n",
"\r\n",
"def tokenize_english(text):\r\n",
" \"\"\"\r\n",
" Tokenizes english text using spacy\r\n",
" \"\"\"\r\n",
" return [tok.text for tok in spacy_en.tokenizer(text)]"
],
"execution_count": 28,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "GtT-9UtAm3WG"
},
"source": [
"SRC = Field(tokenize = tokenize_german, \r\n",
" init_token = '<sos>', \r\n",
" eos_token = '<eos>', \r\n",
" lower = True, \r\n",
" batch_first = True)\r\n",
"\r\n",
"TRG = Field(tokenize = tokenize_english, \r\n",
" init_token = '<sos>', \r\n",
" eos_token = '<eos>', \r\n",
" lower = True, \r\n",
" batch_first = True)"
],
"execution_count": 29,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "Y6A7y_zQoN-y"
},
"source": [
"english = data.Field(tokenize_english,\r\n",
" lower = True,\r\n",
" init_token = \"<sos>\",\r\n",
" eos_token=\"<eos>\")\r\n",
"\r\n",
"german = data.Field(tokenize_german,\r\n",
" lower = True,\r\n",
" init_token = \"<sos>\",\r\n",
" eos_token= \"<eos>\")"
],
"execution_count": 30,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "BLs5syhsqypH"
},
"source": [
"train_data, valid_data, test_data = Multi30k.splits(\r\n",
" exts=(\".de\",\".en\"), fields = (english,german)\r\n",
")"
],
"execution_count": 31,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "iyJkcyM5yFoG"
},
"source": [
"german.build_vocab(train_data,max_size=10000,min_freq=2)\r\n",
"english.build_vocab(train_data,max_size=10000,min_freq=2)"
],
"execution_count": 32,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "nVsdO5CXrM76"
},
"source": [
"##Implementing the model\r\n",
"class Transformer(nn.Module):\r\n",
" def __init__(self,\r\n",
" embedding_size,\r\n",
" src_vocab_size,\r\n",
" trg_vocab_size,\r\n",
" src_pad_idx,\r\n",
" num_heads,\r\n",
" num_encoder_layers,\r\n",
" num_decoder_layers,\r\n",
" forward_expansion,\r\n",
" dropout,\r\n",
" max_len,\r\n",
" device):\r\n",
" super().__init__()\r\n",
" self.src_word_embedding = nn.Embedding(src_vocab_size, embedding_size)\r\n",
" self.src_position_embedding = nn.Embedding(max_len, embedding_size)\r\n",
" self.trg_word_embedding = nn.Embedding(trg_vocab_size, embedding_size)\r\n",
" self.trg_position_embedding = nn.Embedding(max_len, embedding_size)\r\n",
"\r\n",
" self.device = device\r\n",
" self.transformer = nn.Transformer(\r\n",
" embedding_size,\r\n",
" num_heads,\r\n",
" num_encoder_layers,\r\n",
" num_decoder_layers,\r\n",
" forward_expansion,\r\n",
" dropout\r\n",
" )\r\n",
" self.fc_out = nn.Linear(embedding_size,trg_vocab_size)\r\n",
" self.dropout = nn.Dropout(dropout)\r\n",
" self.src_pad_idx = src_pad_idx\r\n",
"\r\n",
" def make_src_mask(self,src):\r\n",
" src_mask = src.transpose(0,1) == self.src_pad_idx\r\n",
" #(N,src_len)\r\n",
" return src_mask.to(self.device)\r\n",
"\r\n",
" def forward(self,src,trg):\r\n",
" src_seq_len, N = src.shape\r\n",
" trg_seq_len, N = trg.shape\r\n",
"\r\n",
" src_positions = (\r\n",
" torch.arange(0, src_seq_len)\r\n",
" .unsqueeze(1)\r\n",
" .expand(src_seq_len, N)\r\n",
" .to(self.device)\r\n",
" )\r\n",
"\r\n",
" trg_positions = (\r\n",
" torch.arange(0, trg_seq_len)\r\n",
" .unsqueeze(1)\r\n",
" .expand(trg_seq_len, N)\r\n",
" .to(self.device)\r\n",
" )\r\n",
"\r\n",
" embed_src = self.dropout(\r\n",
" (self.src_word_embedding(src) + self.src_position_embedding(src_positions))\r\n",
" )\r\n",
" embed_trg = self.dropout(\r\n",
" (self.trg_word_embedding(trg) + self.trg_position_embedding(trg_positions))\r\n",
" )\r\n",
"\r\n",
" src_padding_mask = self.make_src_mask(src)\r\n",
" trg_mask = self.transformer.generate_square_subsequent_mask(trg_seq_len).to(self.device)\r\n",
"\r\n",
" out = self.transformer(\r\n",
" embed_src,\r\n",
" embed_trg,\r\n",
" src_key_padding_mask=src_padding_mask,\r\n",
" tgt_mask=trg_mask,\r\n",
" )\r\n",
" out = self.fc_out(out)\r\n",
" return out"
],
"execution_count": 33,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "pISz5k5V-0lB"
},
"source": [
"def save_checkpoint(state, filename=\"my_checkpoint.pth,tar\"):\r\n",
" print(\"---->Saving checkpoint\")\r\n",
" torch.save(state,filename)\r\n",
"\r\n",
"def load_checkpoint(stae, model, optimizer):\r\n",
" print(\"----->Loading checkpoint\")\r\n",
" model.load_state_dict(checkpoint[\"state_dict\"])\r\n",
" optimizer.load_state_dict(checkpoint[\"optimizer\"])\r\n",
"\r\n",
"\r\n",
"def translate_sentence(model, sentence, german, english, device, max_length=50):\r\n",
" #load german tokenizer\r\n",
" spacy_ger = spacy.load(\"de_core_news_sm\")\r\n",
" #create tokens in spacy and convert everything into lower case\r\n",
" if type(sentence) == str:\r\n",
" tokens = [token.text.lower() for token in spacy_ger(sentence)]\r\n",
" else:\r\n",
" tokens = [token.lower() for token in sentence]\r\n",
"\r\n",
" #Add <SOS> and <EOS> token in beginning and end\r\n",
" tokens.insert(0, german.init_token)\r\n",
" tokens.append(german.eos_token)\r\n",
"\r\n",
" #convert text to indices---->Numericalize them\r\n",
" text_to_indices = [german.vocab.stoi[tok] for tok in tokens]\r\n",
" #convert to tensors\r\n",
" sentence_tensor = torch.LongTensor(text_to_indices).unsqueeze(1).to(device)\r\n",
" outputs = [english.vocab.stoi[\"<sos>\"]]\r\n",
" for i in range(max_length):\r\n",
" trg_tensor = torch.LongTensor(outputs).unsqueeze(1).to(device)\r\n",
"\r\n",
" with torch.no_grad():\r\n",
" output = model(sentence_tensor, trg_tensor)\r\n",
"\r\n",
" best_pred = output.argmax(2)[-1, :].item()\r\n",
" outputs.append(best_pred)\r\n",
"\r\n",
" if best_pred == english.vocab.stoi[\"<eos>\"]:\r\n",
" break\r\n",
"\r\n",
" translated_sentence = [english.vocab.itos[idx] for idx in outputs]\r\n",
" #remove start token\r\n",
" return translated_sentence[1:]\r\n",
"\r\n",
"def blue(data, model, german, english, device):\r\n",
" target = []\r\n",
" outputs = []\r\n",
"\r\n",
" for example in data:\r\n",
" src = vars(example)[\"src\"]\r\n",
" trg = vars(example)[\"trg\"]\r\n",
"\r\n",
" prediction = translate_sentence(model,src,german,english,device)\r\n",
" prediction = prediction[:-1] #remove <eos> token\r\n",
"\r\n",
" target.append([trg])\r\n",
" outputs.append(prediction)\r\n",
"\r\n",
" return blue_score(outputs, targets)"
],
"execution_count": 34,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "H4mJoG1MyQ8e"
},
"source": [
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\r\n",
"\r\n",
"load_model = False\r\n",
"save_model = True\r\n",
"\r\n",
"#Training Hyperparameters\r\n",
"n_epochs = 10000\r\n",
"learning_rate = 3e-4\r\n",
"batch_size = 32\r\n",
"\r\n",
"#model hyperparameters\r\n",
"src_vocab_size = len(german.vocab)\r\n",
"trg_vocab_size = len(english.vocab)\r\n",
"embedding_size = 512\r\n",
"num_heads = 8\r\n",
"num_encoder_layers = 3\r\n",
"num_decoder_layers = 3\r\n",
"dropout = 0.10\r\n",
"max_len = 100\r\n",
"forward_expansion = 4\r\n",
"src_pad_idx = english.vocab.stoi[\"<pad>\"]"
],
"execution_count": 35,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "xmqgm0DWzq0f"
},
"source": [
"train_iterator,valid_iterator,test_iterator = BucketIterator.splits(\r\n",
" (train_data,valid_data,test_data),\r\n",
" batch_size=batch_size,\r\n",
" sort_within_batch=True,\r\n",
" sort_key=lambda x: len(x.src),\r\n",
" device=device\r\n",
")"
],
"execution_count": 36,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 340
},
"id": "ZTkPS6HS0OJz",
"outputId": "75628736-3d92-4a9a-ae49-454a8b94f6fb"
},
"source": [
"model = Transformer(\r\n",
" embedding_size,\r\n",
" src_vocab_size,\r\n",
" trg_vocab_size,\r\n",
" src_pad_idx,\r\n",
" num_heads,\r\n",
" num_encoder_layers,\r\n",
" num_decoder_layers,\r\n",
" forward_expansion,\r\n",
" dropout,\r\n",
" max_len,\r\n",
" device\r\n",
").to(device)"
],
"execution_count": 37,
"outputs": [
{
"output_type": "error",
"ename": "RuntimeError",
"evalue": "ignored",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-37-a8b2a1a1c17a>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 11\u001b[0m \u001b[0mmax_len\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 12\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 13\u001b[0;31m ).to(device)\n\u001b[0m",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36mto\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 610\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdtype\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_floating_point\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_blocking\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 611\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 612\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mconvert\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 613\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 614\u001b[0m def register_backward_hook(\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_apply\u001b[0;34m(self, fn)\u001b[0m\n\u001b[1;32m 357\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 358\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mmodule\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mchildren\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 359\u001b[0;31m \u001b[0mmodule\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 360\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 361\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mcompute_should_use_set_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtensor_applied\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_apply\u001b[0;34m(self, fn)\u001b[0m\n\u001b[1;32m 379\u001b[0m \u001b[0;31m# `with torch.no_grad():`\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 380\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mno_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 381\u001b[0;31m \u001b[0mparam_applied\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mparam\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 382\u001b[0m \u001b[0mshould_use_set_data\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcompute_should_use_set_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mparam\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparam_applied\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 383\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mshould_use_set_data\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36mconvert\u001b[0;34m(t)\u001b[0m\n\u001b[1;32m 608\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mconvert_to_format\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m4\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 609\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdtype\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_floating_point\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_blocking\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmemory_format\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mconvert_to_format\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 610\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdtype\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_floating_point\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_blocking\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 611\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 612\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mconvert\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mRuntimeError\u001b[0m: CUDA error: device-side assert triggered"
]
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "WvpUgcRz0ekq"
},
"source": [
"def count_parameters(model):\r\n",
" return sum(p.numel() for p in model.parameters() if p.requires_grad)\r\n",
"\r\n",
"print(f\"The model has {count_parameters(model):,} trainable parameters\")"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "QvoKGVza0p7d"
},
"source": [
"optimizer = optim.AdamW(model.parameters())\r\n",
"\r\n",
"schdeuler = optim.lr_scheduler.ReduceLROnPlateau(\r\n",
" optimizer, factor=0.1,patience=10, verbose=True\r\n",
")\r\n",
"pad_idx = english.vocab.stoi[\"<pad>\"]\r\n",
"criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "27vttEJ--r2t"
},
"source": [
"if load_model:\r\n",
" load_checkpoint(torch.load(\"my_checkpoint.pth.tar\"),model,optimizer)\r\n",
"\r\n",
"sentence = \"ein pferd geht unter einer brucke neben einem boot\"\r\n",
"\r\n",
"for epoch in range(n_epochs):\r\n",
" print(f\"[Epoch {epoch}/{n_epochs}\")\r\n",
"\r\n",
" if save_model:\r\n",
" checkpoint = {\r\n",
" \"state_dict\": model.state_dict(),\r\n",
" \"optimizer\" : optimizer.state_dict(),\r\n",
" }\r\n",
" save_checkpoint(checkpoint)\r\n",
"\r\n",
" model.eval()\r\n",
" translated_sentence = translate_sentence(\r\n",
" model, sentence, german, english, device, max_length=50\r\n",
" )\r\n",
" print(f\"Translated example sentence:\\n {translated_sentence}\")\r\n",
" model.train()\r\n",
" losses = []\r\n",
"\r\n",
" for batch_idx, data in enumerate(train_iterator):\r\n",
" #send inputs and targets to cuda\r\n",
" inp = data.src.to(device)\r\n",
" target = data.trg.to(device)\r\n",
" #forward prop\r\n",
" output = model(inp, target)\r\n",
"\r\n",
" #ouput shape--->[trg_len,batch_size,output_dim]\r\n",
" #reshape it appropriately for cross entropy loss\r\n",
" output = output.reshape(-1, output.shape[2])\r\n",
" target = target[1:].reshape(-1)\r\n",
"\r\n",
" optimizer.zero_grad\r\n",
" loss = criterion(output,target)\r\n",
" losses.append(loss.item())\r\n",
" #back prop\r\n",
" loss.backward()\r\n",
" #clip to avoid exploding gradients\r\n",
" torch.nn.utils.clip_grad_norm_(model.parameters(),max_norm=1)\r\n",
" #gradient descent step\r\n",
" optimizer.step()\r\n",
"\r\n",
" mean_loss = sum(losses)/len(losses)\r\n",
" schdeuler.step(mean_loss)\r\n",
"\r\n",
"score = bleu(test_data[1:100], model, german, english, device)\r\n",
"print(f\"Blue score {score*100:.2f}\") "
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "4Dzu7EtTCbbh"
},
"source": [
""
],
"execution_count": null,
"outputs": []
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment