Skip to content

Instantly share code, notes, and snippets.

@bentrevett
Last active March 31, 2022 02:05
Show Gist options
  • Save bentrevett/1eaf9c512733345787c7fc7272a3f8ff to your computer and use it in GitHub Desktop.
Save bentrevett/1eaf9c512733345787c7fc7272a3f8ff to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Using the Model for Inference\n",
"\n",
"- the only changes we need to make are to the `Seq2Seq` model\n",
"- if we pass a `trg` (target) of `None`, we now make a dummy target that is of length 25 (chosen arbitrarily)\n",
"- the dummy target is filled with `2`, as this is the index of the `<sos>` token in the `TRG` vocab (you can verify this with `TRG.vocab.stoi['<sos>']`)\n",
"- during inference `teacher_forcing_ratio = 0`, so the `trg` tensor is only ever used as a decoder input at the very first time step, where it provides the `<sos>` token\n",
"- we add some functions at the end to translate a single sentence (read the comments below)"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"\n",
"from torchtext.datasets import TranslationDataset, Multi30k\n",
"from torchtext.data import Field, BucketIterator\n",
"\n",
"import spacy\n",
"\n",
"import random\n",
"import math\n",
"import os"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"SEED = 1234\n",
"\n",
"# Seed Python's RNG (drives teacher forcing) and torch's CPU/CUDA RNGs, in that order.\n",
"for seed_fn in (random.seed, torch.manual_seed, torch.cuda.manual_seed):\n",
"    seed_fn(SEED)\n",
"\n",
"# Ask cuDNN for deterministic kernels so repeated runs match.\n",
"torch.backends.cudnn.deterministic = True"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"def load_spacy_pipeline(package_name, legacy_shortcut):\n",
"    \"\"\"Load a spaCy pipeline by package name, falling back to the pre-v3 shortcut link.\n",
"\n",
"    spaCy v3 removed shortcut links such as 'de'/'en', so spacy.load('de') fails there;\n",
"    older installs may only provide the shortcut, so it is kept as a fallback.\n",
"    \"\"\"\n",
"    try:\n",
"        return spacy.load(package_name)\n",
"    except OSError:\n",
"        return spacy.load(legacy_shortcut)\n",
"\n",
"spacy_de = load_spacy_pipeline('de_core_news_sm', 'de')\n",
"spacy_en = load_spacy_pipeline('en_core_web_sm', 'en')"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"def tokenize_de(text):\n",
"    \"\"\"\n",
"    Split a German string into spaCy token strings, returned in reverse order.\n",
"    \"\"\"\n",
"    tokens = [token.text for token in spacy_de.tokenizer(text)]\n",
"    tokens.reverse()\n",
"    return tokens\n",
"\n",
"def tokenize_en(text):\n",
"    \"\"\"\n",
"    Split an English string into a list of spaCy token strings (original order).\n",
"    \"\"\"\n",
"    return list(map(lambda token: token.text, spacy_en.tokenizer(text)))"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"# Both fields wrap each sequence in <sos>/<eos> and lowercase every token.\n",
"SHARED_FIELD_KWARGS = dict(init_token='<sos>', eos_token='<eos>', lower=True)\n",
"\n",
"SRC = Field(tokenize=tokenize_de, **SHARED_FIELD_KWARGS)\n",
"TRG = Field(tokenize=tokenize_en, **SHARED_FIELD_KWARGS)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# German (.de) is the source language, English (.en) the target.\n",
"train_data, valid_data, test_data = Multi30k.splits(\n",
"    exts=('.de', '.en'),\n",
"    fields=(SRC, TRG),\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# Tokens seen fewer than MIN_FREQ times in the training set are not given their own index.\n",
"MIN_FREQ = 2\n",
"\n",
"for field in (SRC, TRG):\n",
"    field.build_vocab(train_data, min_freq=MIN_FREQ)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"BATCH_SIZE = 128\n",
"\n",
"use_cuda = torch.cuda.is_available()\n",
"device = torch.device('cuda') if use_cuda else torch.device('cpu')\n",
"\n",
"splits = (train_data, valid_data, test_data)\n",
"train_iterator, valid_iterator, test_iterator = BucketIterator.splits(\n",
"    splits,\n",
"    batch_size=BATCH_SIZE,\n",
"    device=device,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"class Encoder(nn.Module):\n",
"    \"\"\"Multi-layer LSTM encoder that returns only its final hidden and cell states.\"\"\"\n",
"\n",
"    def __init__(self, input_dim, emb_dim, hid_dim, n_layers, dropout):\n",
"        super().__init__()\n",
"\n",
"        # sizes kept as attributes; hid_dim and n_layers are checked by Seq2Seq\n",
"        self.input_dim = input_dim\n",
"        self.emb_dim = emb_dim\n",
"        self.hid_dim = hid_dim\n",
"        self.n_layers = n_layers\n",
"\n",
"        self.embedding = nn.Embedding(input_dim, emb_dim)\n",
"        self.rnn = nn.LSTM(emb_dim, hid_dim, n_layers, dropout=dropout)\n",
"        # `self.dropout` is the Dropout module, not the float probability\n",
"        self.dropout = nn.Dropout(dropout)\n",
"\n",
"    def forward(self, src):\n",
"        # src: [sent len, batch size] token indexes\n",
"        embedded = self.dropout(self.embedding(src))  # [sent len, batch size, emb dim]\n",
"\n",
"        # the per-step top-layer outputs are not needed, only the final states\n",
"        _, (hidden, cell) = self.rnn(embedded)\n",
"        # hidden, cell: [n layers * n directions, batch size, hid dim]\n",
"\n",
"        return hidden, cell"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"class Decoder(nn.Module):\n",
"    \"\"\"Single-step LSTM decoder: consumes one token per sequence and scores the next.\"\"\"\n",
"\n",
"    def __init__(self, output_dim, emb_dim, hid_dim, n_layers, dropout):\n",
"        super().__init__()\n",
"\n",
"        self.emb_dim = emb_dim\n",
"        self.hid_dim = hid_dim\n",
"        self.output_dim = output_dim  # target vocabulary size, read by Seq2Seq\n",
"        self.n_layers = n_layers\n",
"\n",
"        self.embedding = nn.Embedding(output_dim, emb_dim)\n",
"        self.rnn = nn.LSTM(emb_dim, hid_dim, n_layers, dropout=dropout)\n",
"        self.out = nn.Linear(hid_dim, output_dim)\n",
"        # `self.dropout` is the Dropout module, not the float probability\n",
"        self.dropout = nn.Dropout(dropout)\n",
"\n",
"    def forward(self, input, hidden, cell):\n",
"        # input: [batch size]\n",
"        # hidden, cell: [n layers, batch size, hid dim] (the decoder LSTM is unidirectional)\n",
"\n",
"        # add a length-1 time axis: [1, batch size]\n",
"        step = input.unsqueeze(0)\n",
"        embedded = self.dropout(self.embedding(step))  # [1, batch size, emb dim]\n",
"\n",
"        output, (hidden, cell) = self.rnn(embedded, (hidden, cell))\n",
"        # output: [1, batch size, hid dim]; hidden, cell: [n layers, batch size, hid dim]\n",
"\n",
"        # drop the time axis and project to vocabulary logits: [batch size, output dim]\n",
"        prediction = self.out(output.squeeze(0))\n",
"        return prediction, hidden, cell"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"class Seq2Seq(nn.Module):\n",
"    \"\"\"Encoder-decoder wrapper that unrolls the decoder one target position at a time.\"\"\"\n",
"\n",
"    def __init__(self, encoder, decoder, device):\n",
"        super().__init__()\n",
"\n",
"        self.encoder = encoder\n",
"        self.decoder = decoder\n",
"        self.device = device\n",
"\n",
"        # the encoder's final (hidden, cell) initialise the decoder, so their shapes must match\n",
"        assert encoder.hid_dim == decoder.hid_dim, \"Hidden dimensions of encoder and decoder must be equal!\"\n",
"        assert encoder.n_layers == decoder.n_layers, \"Encoder and decoder must have equal number of layers!\"\n",
"\n",
"    def forward(self, src, trg=None, teacher_forcing_ratio=0.5, max_len=25, sos_idx=2):\n",
"        \"\"\"Run the encoder once, then the decoder for every target position after the first.\n",
"\n",
"        Args:\n",
"            src: [src len, batch size] source token indexes.\n",
"            trg: [trg len, batch size] target token indexes, or None for inference.\n",
"            teacher_forcing_ratio: probability of feeding the ground-truth token instead of\n",
"                the model's own prediction; must be 0 when trg is None.\n",
"            max_len: number of output positions produced when trg is None (default 25).\n",
"            sos_idx: index used to start decoding when trg is None (default 2, the <sos>\n",
"                index of TRG here; verify with TRG.vocab.stoi['<sos>']).\n",
"\n",
"        Returns:\n",
"            [trg len, batch size, output dim] logits; position 0 is left as zeros.\n",
"        \"\"\"\n",
"        if trg is None:\n",
"            # inference: only the start token of the dummy target is ever fed to the decoder\n",
"            assert teacher_forcing_ratio == 0, \"Must be zero during inference\"\n",
"            trg = torch.full((max_len, src.shape[1]), sos_idx, dtype=torch.long, device=src.device)\n",
"\n",
"        trg_len, batch_size = trg.shape[0], trg.shape[1]\n",
"        trg_vocab_size = self.decoder.output_dim\n",
"\n",
"        # outputs[0] is never written, matching the start-token row of trg\n",
"        outputs = torch.zeros(trg_len, batch_size, trg_vocab_size, device=self.device)\n",
"\n",
"        # last hidden/cell state of the encoder is the initial state of the decoder\n",
"        hidden, cell = self.encoder(src)\n",
"\n",
"        # first decoder input is the start-token row of trg\n",
"        input = trg[0, :]\n",
"\n",
"        for t in range(1, trg_len):\n",
"            output, hidden, cell = self.decoder(input, hidden, cell)\n",
"            outputs[t] = output\n",
"            # one random draw per step, whether or not teacher forcing can fire\n",
"            teacher_force = random.random() < teacher_forcing_ratio\n",
"            top1 = output.argmax(1)\n",
"            input = trg[t] if teacher_force else top1\n",
"\n",
"        return outputs"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"# vocabulary sizes come from the built fields\n",
"INPUT_DIM, OUTPUT_DIM = len(SRC.vocab), len(TRG.vocab)\n",
"\n",
"ENC_EMB_DIM = DEC_EMB_DIM = 256\n",
"HID_DIM = 512\n",
"N_LAYERS = 2\n",
"ENC_DROPOUT = DEC_DROPOUT = 0.5\n",
"\n",
"enc = Encoder(INPUT_DIM, ENC_EMB_DIM, HID_DIM, N_LAYERS, ENC_DROPOUT)\n",
"dec = Decoder(OUTPUT_DIM, DEC_EMB_DIM, HID_DIM, N_LAYERS, DEC_DROPOUT)\n",
"model = Seq2Seq(enc, dec, device).to(device)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"# Adam with PyTorch's default hyper-parameters\n",
"optimizer = optim.Adam(params=model.parameters())"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"# padding positions in the target are excluded from the loss\n",
"pad_idx = TRG.vocab.stoi[TRG.pad_token]\n",
"criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"def train(model, iterator, optimizer, criterion, clip):\n",
"    \"\"\"Run one training epoch (default teacher forcing ratio) and return the mean per-batch loss.\"\"\"\n",
"    model.train()\n",
"    total_loss = 0.0\n",
"\n",
"    for batch in iterator:\n",
"        src, trg = batch.src, batch.trg  # [sent len, batch size]\n",
"\n",
"        optimizer.zero_grad()\n",
"        output = model(src, trg)  # [sent len, batch size, output dim]\n",
"\n",
"        # skip position 0 (the start token / unfilled output) and flatten time and batch\n",
"        vocab_size = output.shape[2]\n",
"        logits = output[1:].view(-1, vocab_size)\n",
"        targets = trg[1:].view(-1)\n",
"        loss = criterion(logits, targets)\n",
"\n",
"        loss.backward()\n",
"        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)\n",
"        optimizer.step()\n",
"\n",
"        total_loss += loss.item()\n",
"\n",
"    return total_loss / len(iterator)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"def evaluate(model, iterator, criterion):\n",
"    \"\"\"Return the mean per-batch loss over `iterator` with teacher forcing off and no gradients.\"\"\"\n",
"    model.eval()\n",
"    batch_losses = []\n",
"\n",
"    with torch.no_grad():\n",
"        for batch in iterator:\n",
"            # ratio 0: the decoder always feeds back its own predictions\n",
"            output = model(batch.src, batch.trg, 0)\n",
"            vocab_size = output.shape[2]\n",
"            loss = criterion(output[1:].view(-1, vocab_size), batch.trg[1:].view(-1))\n",
"            batch_losses.append(loss.item())\n",
"\n",
"    return sum(batch_losses) / len(iterator)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"| Epoch: 001 | Train Loss: 5.039 | Train PPL: 154.308 | Val. Loss: 4.836 | Val. PPL: 125.907 |\n",
"| Epoch: 002 | Train Loss: 4.491 | Train PPL: 89.214 | Val. Loss: 4.699 | Val. PPL: 109.889 |\n",
"| Epoch: 003 | Train Loss: 4.220 | Train PPL: 68.064 | Val. Loss: 4.480 | Val. PPL: 88.226 |\n",
"| Epoch: 004 | Train Loss: 3.999 | Train PPL: 54.570 | Val. Loss: 4.341 | Val. PPL: 76.763 |\n",
"| Epoch: 005 | Train Loss: 3.801 | Train PPL: 44.730 | Val. Loss: 4.194 | Val. PPL: 66.307 |\n",
"| Epoch: 006 | Train Loss: 3.641 | Train PPL: 38.134 | Val. Loss: 4.058 | Val. PPL: 57.859 |\n",
"| Epoch: 007 | Train Loss: 3.512 | Train PPL: 33.529 | Val. Loss: 3.930 | Val. PPL: 50.892 |\n",
"| Epoch: 008 | Train Loss: 3.382 | Train PPL: 29.439 | Val. Loss: 3.852 | Val. PPL: 47.078 |\n",
"| Epoch: 009 | Train Loss: 3.267 | Train PPL: 26.223 | Val. Loss: 3.802 | Val. PPL: 44.805 |\n",
"| Epoch: 010 | Train Loss: 3.193 | Train PPL: 24.373 | Val. Loss: 3.762 | Val. PPL: 43.038 |\n",
"| Epoch: 011 | Train Loss: 3.091 | Train PPL: 21.991 | Val. Loss: 3.743 | Val. PPL: 42.234 |\n",
"| Epoch: 012 | Train Loss: 3.012 | Train PPL: 20.328 | Val. Loss: 3.697 | Val. PPL: 40.311 |\n",
"| Epoch: 013 | Train Loss: 2.939 | Train PPL: 18.904 | Val. Loss: 3.641 | Val. PPL: 38.128 |\n",
"| Epoch: 014 | Train Loss: 2.870 | Train PPL: 17.630 | Val. Loss: 3.648 | Val. PPL: 38.389 |\n",
"| Epoch: 015 | Train Loss: 2.833 | Train PPL: 17.004 | Val. Loss: 3.643 | Val. PPL: 38.224 |\n",
"| Epoch: 016 | Train Loss: 2.762 | Train PPL: 15.835 | Val. Loss: 3.616 | Val. PPL: 37.174 |\n",
"| Epoch: 017 | Train Loss: 2.700 | Train PPL: 14.876 | Val. Loss: 3.617 | Val. PPL: 37.233 |\n",
"| Epoch: 018 | Train Loss: 2.654 | Train PPL: 14.216 | Val. Loss: 3.614 | Val. PPL: 37.098 |\n",
"| Epoch: 019 | Train Loss: 2.598 | Train PPL: 13.443 | Val. Loss: 3.602 | Val. PPL: 36.689 |\n",
"| Epoch: 020 | Train Loss: 2.524 | Train PPL: 12.474 | Val. Loss: 3.607 | Val. PPL: 36.861 |\n",
"| Epoch: 021 | Train Loss: 2.491 | Train PPL: 12.075 | Val. Loss: 3.560 | Val. PPL: 35.148 |\n",
"| Epoch: 022 | Train Loss: 2.443 | Train PPL: 11.505 | Val. Loss: 3.599 | Val. PPL: 36.552 |\n",
"| Epoch: 023 | Train Loss: 2.385 | Train PPL: 10.860 | Val. Loss: 3.586 | Val. PPL: 36.101 |\n",
"| Epoch: 024 | Train Loss: 2.349 | Train PPL: 10.477 | Val. Loss: 3.635 | Val. PPL: 37.894 |\n",
"| Epoch: 025 | Train Loss: 2.305 | Train PPL: 10.021 | Val. Loss: 3.537 | Val. PPL: 34.362 |\n"
]
}
],
"source": [
"N_EPOCHS = 25\n",
"CLIP = 10\n",
"SAVE_DIR = 'models'\n",
"MODEL_SAVE_PATH = os.path.join(SAVE_DIR, 'tut1_model.pt')\n",
"\n",
"# exist_ok avoids the check-then-create race of isdir() followed by makedirs()\n",
"os.makedirs(SAVE_DIR, exist_ok=True)\n",
"\n",
"best_valid_loss = float('inf')\n",
"\n",
"for epoch in range(N_EPOCHS):\n",
"    train_loss = train(model, train_iterator, optimizer, criterion, CLIP)\n",
"    valid_loss = evaluate(model, valid_iterator, criterion)\n",
"\n",
"    # checkpoint only on improvement, so the file always holds the best validation epoch\n",
"    if valid_loss < best_valid_loss:\n",
"        best_valid_loss = valid_loss\n",
"        torch.save(model.state_dict(), MODEL_SAVE_PATH)\n",
"\n",
"    print(f'| Epoch: {epoch+1:03} | Train Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f} | Val. Loss: {valid_loss:.3f} | Val. PPL: {math.exp(valid_loss):7.3f} |')"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"| Test Loss: 3.544 | Test PPL: 34.599 |\n"
]
}
],
"source": [
"# map_location lets a checkpoint saved on a GPU be restored on a CPU-only machine\n",
"model.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location=device))\n",
"\n",
"test_loss = evaluate(model, test_iterator, criterion)\n",
"\n",
"print(f'| Test Loss: {test_loss:.3f} | Test PPL: {math.exp(test_loss):7.3f} |')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Below is a function to translate a sentence.**\n",
"\n",
"- tokenize the German using the same tokenization for the dataset\n",
"- convert the tokens into indexes using `SRC.vocab.stoi` (stoi = string to int)\n",
"- convert the indexes into a tensor, place it on the GPU (if available) and add a batch dimension\n",
"- feed the tensor into our model, making sure to pass a `trg` of `None` and `teacher_forcing_ratio = 0`\n",
"- convert the translation probabilities into predictions taking the highest predicted output word with `torch.argmax`\n",
"- convert the predictions into words using `TRG.vocab.itos` (itos = int to string), remembering to ignore the first token as `output[0]` is always zero"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"def translate_sentence(sentence):\n",
"    \"\"\"Greedily translate one German sentence with the global `model`.\n",
"\n",
"    The source is preprocessed the same way as the training data: tokenized (and\n",
"    reversed) by `tokenize_de`, lowercased when `SRC.lower` is set, and wrapped in\n",
"    the SRC init/eos tokens. Otherwise capitalized words always map to <unk>.\n",
"\n",
"    Returns a list of English tokens. The first output position is dropped because\n",
"    Seq2Seq never fills it. Decoding runs for a fixed number of steps, so the list may\n",
"    end with repeated <eos> tokens.\n",
"    \"\"\"\n",
"    tokens = tokenize_de(sentence)\n",
"    if SRC.lower:\n",
"        tokens = [tok.lower() for tok in tokens]\n",
"    # training sources were padded as <sos> + reversed tokens + <eos>\n",
"    tokens = [SRC.init_token] + tokens + [SRC.eos_token]\n",
"    indexes = [SRC.vocab.stoi[tok] for tok in tokens]  # out-of-vocabulary tokens become <unk>\n",
"    # [src len] -> [src len, 1]: a batch of one sentence\n",
"    src_tensor = torch.LongTensor(indexes).unsqueeze(1).to(device)\n",
"\n",
"    # eval() disables dropout; no_grad() avoids building an autograd graph\n",
"    model.eval()\n",
"    with torch.no_grad():\n",
"        logits = model(src_tensor, None, 0).squeeze(1)  # [max len, output dim]\n",
"\n",
"    predictions = logits.argmax(1).tolist()\n",
"    return [TRG.vocab.itos[idx] for idx in predictions][1:]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's print out a candidate German sentence and its corresponding translation. \n",
"\n",
"We get this first example from the training set, so the model should do a good job at translating it.\n",
"\n",
"Note that the stored `src` tokens are reversed, so we reverse them back: `candidate` must be in normal (forward) word order, as our `translate_sentence` function will reverse it for us."
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"zwei junge weiße männer sind im freien in der nähe vieler büsche .\n",
"two young , white males are outside near many bushes .\n"
]
}
],
"source": [
"first_train = vars(train_data.examples[0])\n",
"\n",
"# src is stored reversed by tokenize_de; undo that so the sentence reads naturally\n",
"candidate = ' '.join(reversed(first_train['src']))\n",
"candidate_translation = ' '.join(first_train['trg'])\n",
"\n",
"print(candidate)\n",
"print(candidate_translation)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['two',\n",
" 'young',\n",
" 'men',\n",
" 'are',\n",
" 'outside',\n",
" 'near',\n",
" 'near',\n",
" 'trees',\n",
" '.',\n",
" '<eos>',\n",
" '<eos>',\n",
" '<eos>',\n",
" '<eos>',\n",
" '<eos>',\n",
" '<eos>',\n",
" '<eos>',\n",
" '<eos>',\n",
" '<eos>',\n",
" '<eos>',\n",
" '<eos>',\n",
" '<eos>',\n",
" '<eos>',\n",
" '<eos>',\n",
" '<eos>']"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# greedy translation of the training example above\n",
"translate_sentence(sentence=candidate)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Not bad, didn't mention the men were \"white\", got \"trees\" and \"bushes\" mixed up, and repeated \"near\".\n",
"\n",
"We'll repeat this with an example from the validation set. \n",
"\n",
"This will test how well the model has generalized to examples it has not been trained on."
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"eine gruppe von männern lädt baumwolle auf einen lastwagen\n",
"a group of men are loading cotton onto a truck\n"
]
}
],
"source": [
"first_valid = vars(valid_data.examples[0])\n",
"\n",
"# src is stored reversed by tokenize_de; undo that so the sentence reads naturally\n",
"candidate = ' '.join(reversed(first_valid['src']))\n",
"candidate_translation = ' '.join(first_valid['trg'])\n",
"\n",
"print(candidate)\n",
"print(candidate_translation)"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['a',\n",
" 'group',\n",
" 'of',\n",
" 'men',\n",
" 'men',\n",
" 'are',\n",
" 'a',\n",
" 'a',\n",
" 'a',\n",
" '.',\n",
" '<eos>',\n",
" '<eos>',\n",
" '<eos>',\n",
" '<eos>',\n",
" '<eos>',\n",
" '<eos>',\n",
" '<eos>',\n",
" '<eos>',\n",
" '<eos>',\n",
" '<eos>',\n",
" '<eos>',\n",
" '<eos>',\n",
" '<eos>',\n",
" '<eos>']"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# greedy translation of the validation example above\n",
"translate_sentence(sentence=candidate)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Not great. Possibly 'cotton' and 'truck' didn't make it into our vocabulary, so the model had trouble translating those words and got confused.\n",
"\n",
"Finally, we'll check how it does on an example from the test set."
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"ein mann mit einem orangefarbenen hut , der etwas anstarrt .\n",
"a man in an orange hat starring at something .\n"
]
}
],
"source": [
"first_test = vars(test_data.examples[0])\n",
"\n",
"# src is stored reversed by tokenize_de; undo that so the sentence reads naturally\n",
"candidate = ' '.join(reversed(first_test['src']))\n",
"candidate_translation = ' '.join(first_test['trg'])\n",
"\n",
"print(candidate)\n",
"print(candidate_translation)"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['a',\n",
" 'man',\n",
" 'in',\n",
" 'a',\n",
" 'orange',\n",
" 'hat',\n",
" 'is',\n",
" 'looking',\n",
" 'at',\n",
" 'something',\n",
" '.',\n",
" '<eos>',\n",
" '<eos>',\n",
" '<eos>',\n",
" '<eos>',\n",
" '<eos>',\n",
" '<eos>',\n",
" '<eos>',\n",
" '<eos>',\n",
" '<eos>',\n",
" '<eos>',\n",
" '<eos>',\n",
" '<eos>',\n",
" '<eos>']"
]
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# greedy translation of the test example above\n",
"translate_sentence(sentence=candidate)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Not perfect, but it's pretty good! The only mistake is that it output \"a\" instead of \"an\". Interestingly, \"starring\" in the reference is a typo for \"staring\", but our model still managed to output the sensible \"looking\". \n",
"\n",
"OK, now let's check how well it does on the example sentence you posted."
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"# a free-form sentence that is not taken from the dataset\n",
"candidate = \"das ist das Ergebnis meines Modells\""
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['the',\n",
" '<unk>',\n",
" '<unk>',\n",
" '<unk>',\n",
" '<unk>',\n",
" '<unk>',\n",
" '.',\n",
" '<eos>',\n",
" '<eos>',\n",
" '<eos>',\n",
" '<eos>',\n",
" '<eos>',\n",
" '<eos>',\n",
" '<eos>',\n",
" '<eos>',\n",
" '<eos>',\n",
" '<eos>',\n",
" '<eos>',\n",
" '<eos>',\n",
" '<eos>',\n",
" '<eos>',\n",
" '<eos>',\n",
" '<eos>',\n",
" '<eos>']"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# greedy translation of the free-form sentence\n",
"translate_sentence(sentence=candidate)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Well, that's pretty awful. I believe the issues are due to the majority of the sentence being tokens outside of the vocabulary.\n",
"\n",
"We can check this by passing the candidate sentence through the vocabulary and then back again, checking how many words it has converted into `<unk>` tokens."
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['das', 'ist', 'das', '<unk>', '<unk>', '<unk>']\n"
]
}
],
"source": [
"# Mirror the SRC field's preprocessing: tokenize (which reverses) and lowercase when\n",
"# SRC.lower is set; otherwise capitalized words show up as <unk> purely because of case.\n",
"tokenized = tokenize_de(candidate)\n",
"if SRC.lower:\n",
"    tokenized = [tok.lower() for tok in tokenized]\n",
"numericalized = [SRC.vocab.stoi[t] for t in tokenized]\n",
"# map back to strings and undo the reversal applied by tokenize_de\n",
"back_to_candidate = [SRC.vocab.itos[n] for n in numericalized][::-1]\n",
"print(back_to_candidate)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As you can see, \"Ergebnis meines Modells\" is all out of vocabulary, which causes the translation to be poor, although it's quite surprising that it didn't learn that \"das ist das\" translates to \"this is the\"."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment