@YasuThompson
Created March 17, 2021 15:17
{
"cells": [
{
"cell_type": "code",
"execution_count": 43,
"metadata": {},
"outputs": [],
"source": [
"def evaluate(inp_sentence):\n",
" '''\n",
" The <start> token is [10000], \n",
" and the <end> token is [10001].\n",
" '''\n",
" start_token = [vocab_size]\n",
" end_token = [vocab_size + 1]\n",
"\n",
" '''\n",
" You first encode an input sentences into a tensor with integers. \n",
" '''\n",
" inp_sentence = start_token + bpemb_1.encode_ids(inp_sentence) + end_token\n",
" encoder_input = tf.expand_dims(inp_sentence, 0)\n",
"\n",
" '''\n",
" The translated output is first [10000], which means <start>\n",
" '''\n",
" decoder_input = [vocab_size]\n",
" output = tf.expand_dims(decoder_input, 0)\n",
"\n",
" '''\n",
" In this loop, for MAX_LENGTH times at most, you repeat 10002-class classification, \n",
" that is , you choose one word every loop and append it to the 'output'.\n",
" When you finish MAX_LENGTH loops, or you choose [10001] as a classification result, \n",
" which means <end> token, you stop decoding. \n",
" '''\n",
" \n",
" '''\n",
" During training Transformer-based translators, you put in the whole target sentences, \n",
" but you need to simualte this loop even during training. \n",
" That is why you need look ahead mask during training to hide the upcoming tokens\n",
" which are not decoded yet.\n",
" '''\n",
" for i in range(MAX_LENGTH):\n",
" enc_padding_mask, combined_mask, dec_padding_mask = create_masks(\n",
" encoder_input, output)\n",
"\n",
" # predictions.shape == (batch_size, seq_len, vocab_size)\n",
" predictions, attention_weights_encoder, attention_weights_decoder = transformer(encoder_input, \n",
" output,\n",
" False,\n",
" enc_padding_mask,\n",
" combined_mask,\n",
" dec_padding_mask)\n",
"\n",
" # select the last word from the seq_len dimension\n",
" predictions = predictions[: ,-1:, :] # (batch_size, 1, vocab_size)\n",
"\n",
" predicted_id = tf.cast(tf.argmax(predictions, axis=-1), tf.int32)\n",
"\n",
" # return the result if the predicted_id is equal to the end token\n",
" if predicted_id == vocab_size+1:\n",
" return tf.squeeze(output, axis=0), attention_weights_encoder, attention_weights_decoder\n",
"\n",
" # concatentate the predicted_id to the output which is given to the decoder\n",
" # as its input.\n",
" output = tf.concat([output, predicted_id], axis=-1)\n",
"\n",
" return tf.squeeze(output, axis=0), attention_weights_encoder, attention_weights_decoder\n"
]
},
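{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch of the look-ahead mask mentioned in the docstring above, assuming the helper name `create_look_ahead_mask` from the TensorFlow Transformer tutorial (it is not defined in this gist): `tf.linalg.band_part` keeps the lower triangle of a matrix of ones, and subtracting that from 1 marks every future position with a 1 so the decoder cannot attend to tokens it has not produced yet."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"\n",
"def create_look_ahead_mask(size):\n",
"    # Strictly upper-triangular matrix of ones: position i may not attend to\n",
"    # positions j > i. A value of 1 means 'masked out'.\n",
"    mask = 1 - tf.linalg.band_part(tf.ones((size, size)), -1, 0)\n",
"    return mask  # (seq_len, seq_len)\n",
"\n",
"# For a 4-token target sequence the mask looks like:\n",
"# [[0., 1., 1., 1.],\n",
"#  [0., 0., 1., 1.],\n",
"#  [0., 0., 0., 1.],\n",
"#  [0., 0., 0., 0.]]\n",
"print(create_look_ahead_mask(4))"
]
},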
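{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal usage sketch under a few assumptions: `transformer`, `vocab_size`, `MAX_LENGTH` and the source-side subword model `bpemb_1` are defined in earlier cells, the target side has its own BPEmb model (hypothetically named `bpemb_2` here) whose `decode_ids` method maps subword ids back to text, and the input sentence is only an example."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Greedy-decode one sentence with evaluate() and turn the predicted subword\n",
"# ids back into text. bpemb_2 is a hypothetical target-side BPEmb model.\n",
"sample_sentence = 'this is a problem we have to solve .'\n",
"\n",
"result, attn_enc, attn_dec = evaluate(sample_sentence)\n",
"\n",
"# Drop the special ids (vocab_size is <start>, vocab_size + 1 is <end>)\n",
"# before decoding, because they are not part of the BPE vocabulary.\n",
"predicted_ids = [int(i) for i in result if int(i) < vocab_size]\n",
"print('Predicted ids:', predicted_ids)\n",
"print('Decoded text :', bpemb_2.decode_ids(predicted_ids))"
]
},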
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.8"
}
},
"nbformat": 4,
"nbformat_minor": 4
}