@YasuThompson
Created January 27, 2021 13:30
{
"cells": [
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"@tf.function\n",
"def train_step(inp, targ, enc_hidden):\n",
" loss = 0\n",
"\n",
" '''\n",
" You input a (batch size, max input length) (=(64, 16)) tensor as an input\n",
" and a (batch size, max output length) (=(64, 11)) as an output, and get a loss. \n",
" '''\n",
" \n",
" with tf.GradientTape() as tape:\n",
" ''' \n",
" You put a batch of input sentences as a (64, 16) tensor. \n",
" '''\n",
" enc_output, enc_hidden = encoder(inp, enc_hidden)\n",
"\n",
" '''\n",
" You pass the last hidden state/vector of the encoder to the decoder as its \n",
" inittial layer.\n",
" '''\n",
" \n",
" dec_hidden = enc_hidden\n",
"\n",
" dec_input = tf.expand_dims([targ_lang.word_index['<start>']] * BATCH_SIZE, 1)\n",
" \n",
" '''\n",
" In the encoder part you pass the whole sentence as an input, \n",
" whereas in the decoder part, you pass a word every time step in the loop below. \n",
" '''\n",
" \n",
" '''\n",
" The loop below shows that you \n",
" '''\n",
" # Teacher forcing - feeding the target as the next input\n",
" for t in range(1, targ.shape[1]):\n",
" \n",
" # passing enc_output to the decoder\n",
" predictions, dec_hidden, _ = decoder(dec_input, \n",
" dec_hidden, \n",
" enc_output) # You need encoder outputs to calculate attentions. \n",
"\n",
" loss += loss_function(targ[:, t], predictions)\n",
"\n",
" # using teacher forcing\n",
" dec_input = tf.expand_dims(targ[:, t], 1)\n",
"\n",
" batch_loss = (loss / int(targ.shape[1]))\n",
"\n",
" '''\n",
" Updating the weigths with the three lines below. \n",
" '''\n",
" variables = encoder.trainable_variables + decoder.trainable_variables\n",
" gradients = tape.gradient(loss, variables)\n",
" optimizer.apply_gradients(zip(gradients, variables))\n",
"\n",
" return batch_loss\n"
]
}
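,
{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch of how `train_step` could be driven from a training loop. It assumes a batched `dataset` of (input, target) pairs, a `steps_per_epoch` count, and an `initialize_hidden_state` helper on the encoder; none of these are defined in this gist.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"EPOCHS = 10\n",
"\n",
"for epoch in range(EPOCHS):\n",
"    # Reset the encoder hidden state at the start of every epoch (assumed helper).\n",
"    enc_hidden = encoder.initialize_hidden_state()\n",
"    total_loss = 0\n",
"\n",
"    # dataset and steps_per_epoch are assumed to come from the data pipeline.\n",
"    for batch, (inp, targ) in enumerate(dataset.take(steps_per_epoch)):\n",
"        batch_loss = train_step(inp, targ, enc_hidden)\n",
"        total_loss += batch_loss\n",
"\n",
"    print('Epoch {} Loss {:.4f}'.format(epoch + 1, (total_loss / steps_per_epoch).numpy()))\n"
]
}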
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.9"
}
},
"nbformat": 4,
"nbformat_minor": 4
}