-
-
Save YasuThompson/e73ae58cf2257d898f2772887b8e1253 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
@tf.function
def train_step(inp, targ, enc_hidden):
    """Run one optimizer update over a single batch with teacher forcing.

    Args:
        inp: integer tensor of source-token ids, shape
            (batch, max_input_len) — per the original comments (64, 16);
            confirm against the data pipeline.
        targ: integer tensor of target-token ids, shape
            (batch, max_target_len) — per the original comments (64, 11).
        enc_hidden: initial hidden state handed to the encoder.

    Returns:
        Scalar loss for this batch, averaged over the target time steps.
    """
    total_loss = 0

    with tf.GradientTape() as tape:
        # Encode the entire source batch in a single forward pass.
        enc_output, enc_hidden = encoder(inp, enc_hidden)

        # The decoder starts from the encoder's final hidden state.
        dec_hidden = enc_hidden

        # First decoder input: one <start> token per example in the batch.
        dec_input = tf.expand_dims([targ_lang.word_index['<start>']] * BATCH_SIZE, 1)

        # Teacher forcing: unlike the encoder (whole sentence at once),
        # the decoder consumes one time step per loop iteration, and each
        # next input is the ground-truth token rather than the prediction.
        for step in range(1, targ.shape[1]):
            # enc_output is required so the decoder can compute attention.
            predictions, dec_hidden, _ = decoder(dec_input, dec_hidden, enc_output)

            total_loss += loss_function(targ[:, step], predictions)

            # Feed the true target token as the next decoder input.
            dec_input = tf.expand_dims(targ[:, step], 1)

    batch_loss = (total_loss / int(targ.shape[1]))

    # Backpropagate through both networks and apply a single update.
    trainable = encoder.trainable_variables + decoder.trainable_variables
    grads = tape.gradient(total_loss, trainable)
    optimizer.apply_gradients(zip(grads, trainable))

    return batch_loss
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.7.9" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 4 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment