@YasuThompson
Created January 27, 2021 13:34
{
"cells": [
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"class BahdanauAttention(tf.keras.layers.Layer):\n",
" def __init__(self, units):\n",
" super(BahdanauAttention, self).__init__()\n",
" self.W1 = tf.keras.layers.Dense(units)\n",
" self.W2 = tf.keras.layers.Dense(units)\n",
" self.V = tf.keras.layers.Dense(1)\n",
"\n",
" def call(self, query, values):\n",
" '''\n",
" In the decoder part, you get an embedding vector for an input token every time step. \n",
" You have to calculate attentions using this class at EVERY TIME STEP. \n",
" 'query' is the hidden state of an RNN cell at the time in the decoder part, whose size is (batch_size, 1, 1024).\n",
" 'values' is the outputs of the encoder part, whose size is (batch_size, 16, 1024). \n",
" (*The length of the input is not necessarily 16.)\n",
" \n",
" Attention mehcanism calculates relevances of a query and values with a certain function. \n",
" There are several functions for calculating the relevances, and in this implementation \n",
" we use Bahdanau\"s additive style. \n",
" '''\n",
" \n",
" # query hidden state shape == (batch_size, hidden size)\n",
" # query_with_time_axis shape == (batch_size, 1, hidden size)\n",
" # values shape == (batch_size, max_len, hidden size)\n",
" # we are doing this to broadcast addition along the time axis to calculate the score\n",
" \n",
" '''\n",
" In this implementation, you always need to consider time steps. \n",
" '''\n",
" query_with_time_axis = tf.expand_dims(query, 1)\n",
"\n",
" '''\n",
" You get the attentions between the query and outputs of the encoder below.\n",
" In short, you compare the a word in the decoder with the input. \n",
" This is equivalent to finding the corresponding words in the original language, \n",
" when you are going to write a word in the target language. \n",
" '''\n",
" # score shape == (batch_size, max_length, 1)\n",
" # we get 1 at the last axis because we are applying score to self.V\n",
" # the shape of the tensor before applying self.V is (batch_size, max_length, units)\n",
" score = self.V(tf.nn.tanh(\n",
" self.W1(values) + self.W2(query_with_time_axis)))\n",
"\n",
" '''\n",
" You normalize the score calculated above with a softmax function so that the usm of its values is 1. \n",
" '''\n",
" # attention_weights shape == (batch_size, max_length, 1)\n",
" attention_weights = tf.nn.softmax(score, axis=1)\n",
"\n",
" # context_vector shape after sum == (batch_size, hidden_size)\n",
" '''\n",
" You reweight the outputs of the encoder with attention scores.\n",
" The shape of the resulitng 'context_vector' is (64, 16, 1024)\n",
" '''\n",
" context_vector = attention_weights * values \n",
" '''\n",
" You calculate the weighted average of the reweighted vectors above. \n",
" Thus the size of the shape of the resulting 'context_vector' is (64, 1024). \n",
" '''\n",
" context_vector = tf.reduce_sum(context_vector, axis=1) # You take a weighted average of c\n",
"\n",
" return context_vector, attention_weights"
]
}
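,
{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch of how the layer above can be called, assuming the shapes mentioned in the docstring (batch_size=64, 16 encoder time steps, 1024 hidden units); the number of attention units (10) is an arbitrary choice for illustration."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"\n",
"# Dummy tensors with the shapes from the docstring (illustrative only, not real data).\n",
"sample_hidden = tf.random.normal((64, 1024))      # decoder hidden state (the query)\n",
"sample_output = tf.random.normal((64, 16, 1024))  # encoder outputs (the values)\n",
"\n",
"attention_layer = BahdanauAttention(10)\n",
"context_vector, attention_weights = attention_layer(sample_hidden, sample_output)\n",
"\n",
"print(context_vector.shape)     # (64, 1024)\n",
"print(attention_weights.shape)  # (64, 16, 1)"
]
}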
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.9"
}
},
"nbformat": 4,
"nbformat_minor": 4
}