{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Decoder(tf.keras.Model):\n",
    "    def __init__(self, vocab_size, embedding_dim, dec_units, batch_sz):\n",
    "        super(Decoder, self).__init__()\n",
    "        self.batch_sz = batch_sz\n",
    "        self.dec_units = dec_units\n",
    "        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)\n",
    "        self.gru = tf.keras.layers.GRU(self.dec_units,\n",
    "                                       return_sequences=True,\n",
    "                                       return_state=True,\n",
    "                                       recurrent_initializer='glorot_uniform')\n",
    "        self.fc = tf.keras.layers.Dense(vocab_size)\n",
    "\n",
    "        # used for attention\n",
    "        self.attention = BahdanauAttention(self.dec_units)\n",
    "\n",
    "    def call(self, x, hidden, enc_output):\n",
    "        '''\n",
    "        As with the 'Encoder' class, the inputs of 'Decoder' have the shape [batch, timesteps, feature]\n",
    "        (https://www.tensorflow.org/api_docs/python/tf/keras/layers/GRU).\n",
    "        But keep in mind that the decoder takes one token per time step, so its input shape is\n",
    "        (batch_size, 1, embedding_dim).\n",
    "        '''\n",
    "\n",
    "        '''\n",
    "        You first calculate a 'context_vector' by comparing the decoder hidden state of the\n",
    "        PREVIOUS time step with the outputs of the encoder, because you use Bahdanau's additive\n",
    "        attention mechanism. (In Luong's multiplicative style, by contrast, the hidden state of\n",
    "        the CURRENT time step is used as the query.)\n",
    "        '''\n",
    "        # enc_output shape == (batch_size, max_length, hidden_size)\n",
    "        context_vector, attention_weights = self.attention(hidden, enc_output)\n",
    "\n",
    "        '''\n",
    "        You combine the 'context_vector' with the embedding vector of the decoder input at this\n",
    "        time step, and the RNN cell at the current time step predicts the next word from the\n",
    "        combined input.\n",
    "        '''\n",
    "        # x shape after passing through embedding == (batch_size, 1, embedding_dim)\n",
    "        x = self.embedding(x)\n",
    "        # x shape after concatenation == (batch_size, 1, embedding_dim + hidden_size)\n",
    "        x = tf.concat([tf.expand_dims(context_vector, 1), x], axis=-1)\n",
    "        # passing the concatenated vector to the GRU\n",
    "        output, state = self.gru(x)\n",
    "\n",
    "        # output shape == (batch_size * 1, hidden_size)\n",
    "        output = tf.reshape(output, (-1, output.shape[2]))\n",
    "\n",
    "        # x shape == (batch_size, vocab_size)\n",
    "        x = self.fc(output)\n",
    "        '''\n",
    "        x: a vector of logits whose dimension is the size of the target vocabulary. The index of the\n",
    "        maximum element of this vector is the index of the predicted word at this time step.\n",
    "        state: the hidden state at this time step. This becomes the attention query at the next\n",
    "        time step in Bahdanau's additive style.\n",
    "        '''\n",
    "        return x, state, attention_weights\n",
    "\n",
    "decoder = Decoder(vocab_tar_size, embedding_dim, units, BATCH_SIZE)"
   ]
  },
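  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check, the cell below runs a single decoder step on dummy inputs. It assumes `vocab_tar_size`, `embedding_dim`, `units` and `BATCH_SIZE` are defined in earlier cells of the full notebook, and it uses a hypothetical `max_length` of 10 for the encoder output."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical sanity check; 'vocab_tar_size', 'units' and 'BATCH_SIZE'\n",
    "# are assumed to be defined in earlier cells of the full notebook.\n",
    "max_length = 10  # assumed encoder sequence length for this check\n",
    "\n",
    "sample_enc_output = tf.random.uniform((BATCH_SIZE, max_length, units))\n",
    "sample_hidden = tf.zeros((BATCH_SIZE, units))\n",
    "# one input token per sequence, shape == (batch_size, 1)\n",
    "sample_token = tf.random.uniform((BATCH_SIZE, 1), maxval=vocab_tar_size, dtype=tf.int32)\n",
    "\n",
    "sample_logits, sample_state, sample_attn = decoder(sample_token, sample_hidden, sample_enc_output)\n",
    "\n",
    "# Expected shapes: (BATCH_SIZE, vocab_tar_size), (BATCH_SIZE, units), (BATCH_SIZE, max_length, 1)\n",
    "print(sample_logits.shape, sample_state.shape, sample_attn.shape)"
   ]
  }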
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}