{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate(inp_sentence):\n",
    "    '''\n",
    "    The <start> token is [10000],\n",
    "    and the <end> token is [10001].\n",
    "    '''\n",
    "    start_token = [vocab_size]\n",
    "    end_token = [vocab_size + 1]\n",
    "\n",
    "    '''\n",
    "    You first encode an input sentence into a tensor of integers.\n",
    "    '''\n",
    "    inp_sentence = start_token + bpemb_1.encode_ids(inp_sentence) + end_token\n",
    "    encoder_input = tf.expand_dims(inp_sentence, 0)\n",
    "\n",
    "    '''\n",
    "    The translated output starts as [10000], which is the <start> token.\n",
    "    '''\n",
    "    decoder_input = [vocab_size]\n",
    "    output = tf.expand_dims(decoder_input, 0)\n",
    "\n",
    "    '''\n",
    "    In this loop you repeat a 10002-class classification at most MAX_LENGTH times,\n",
    "    that is, you choose one token per iteration and append it to 'output'.\n",
    "    You stop decoding when you finish MAX_LENGTH loops, or when you choose [10001],\n",
    "    the <end> token, as the classification result.\n",
    "    '''\n",
    "\n",
    "    '''\n",
    "    When you train Transformer-based translators, you feed in the whole target sentence at once,\n",
    "    but you still need to simulate this loop during training.\n",
    "    That is why you need a look-ahead mask during training: it hides the upcoming tokens\n",
    "    which have not been decoded yet.\n",
    "    '''\n",
    "    for i in range(MAX_LENGTH):\n",
    "        enc_padding_mask, combined_mask, dec_padding_mask = create_masks(\n",
    "            encoder_input, output)\n",
    "\n",
    "        # predictions.shape == (batch_size, seq_len, vocab_size)\n",
    "        predictions, attention_weights_encoder, attention_weights_decoder = transformer(encoder_input,\n",
    "                                                                                         output,\n",
    "                                                                                         False,\n",
    "                                                                                         enc_padding_mask,\n",
    "                                                                                         combined_mask,\n",
    "                                                                                         dec_padding_mask)\n",
    "\n",
    "        # select the last token from the seq_len dimension\n",
    "        predictions = predictions[:, -1:, :]  # (batch_size, 1, vocab_size)\n",
    "\n",
    "        predicted_id = tf.cast(tf.argmax(predictions, axis=-1), tf.int32)\n",
    "\n",
    "        # return the result if the predicted_id is equal to the <end> token\n",
    "        if predicted_id == vocab_size + 1:\n",
    "            return tf.squeeze(output, axis=0), attention_weights_encoder, attention_weights_decoder\n",
    "\n",
    "        # concatenate the predicted_id to the output, which is given to the decoder\n",
    "        # as its input.\n",
    "        output = tf.concat([output, predicted_id], axis=-1)\n",
    "\n",
    "    return tf.squeeze(output, axis=0), attention_weights_encoder, attention_weights_decoder\n"
   ]
  },
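  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The look-ahead mask mentioned in the comments above is not built inside `evaluate()` itself; it is hidden in `create_masks()`. The next cell is only a minimal sketch: it assumes the look-ahead mask is constructed with `tf.linalg.band_part`, as in the standard TensorFlow Transformer tutorial, and it shows a hypothetical call of `evaluate()`. The example sentence and the placeholder `bpemb_target` for the target-language BPEmb model are illustrative assumptions, not definitions from this notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "# Sketch (assumption): the look-ahead part of create_masks() is typically built like this.\n",
    "def create_look_ahead_mask(size):\n",
    "    # Ones above the diagonal: position i must not attend to positions j > i.\n",
    "    return 1 - tf.linalg.band_part(tf.ones((size, size)), -1, 0)\n",
    "\n",
    "print(create_look_ahead_mask(4))\n",
    "\n",
    "# Hypothetical usage of evaluate(); run only after the transformer has been trained.\n",
    "# result, _, _ = evaluate('put a source-language sentence here')\n",
    "# ids = [int(i) for i in result.numpy() if int(i) < vocab_size]  # drop <start>/<end>\n",
    "# print(bpemb_target.decode_ids(ids))  # bpemb_target: placeholder for the target-side BPEmb\n"
   ]
  },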
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}