@YasuThompson
Created February 18, 2021 10:56
{
"cells": [
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Inside 'MultiHeadAttention' class...\n",
"\n",
"The shape of 'q' is (1, 12, 512)\n",
"The shape of 'k' is (1, 9, 512)\n",
"The shape of 'v' is (1, 9, 512)\n",
"\n",
"After passing 'q', 'k', 'v' through densely connected layers....\n",
"The shape of 'q' is (1, 12, 512)\n",
"The shape of 'k' is (1, 9, 512)\n",
"The shape of 'v' is (1, 9, 512)\n",
"\n",
"After splitting the heads....\n",
"The shape of 'q' is (1, 8, 12, 64)\n",
"The shape of 'k' is (1, 8, 9, 64)\n",
"The shape of 'v' is (1, 8, 9, 64)\n",
"\n",
"The shape of 'attention_weights' is (1, 8, 12, 9)\n",
"The shape of 'scaled_attention' is (1, 8, 12, 64)\n",
"\n",
"After transposing....\n",
"The shape of 'scaled_attention' is (1, 12, 8, 64)\n",
"\n",
"The shape of 'concat_attention' is (1, 12, 512)\n",
"\n",
"The shape of 'output' is (1, 12, 512)\n"
]
}
],
"source": [
"# As I mentioned, the \"queries\" can be in a different language from the \"keys\" or \"values.\"\n",
"# * They are supposed to be different in translation tasks. \n",
"\n",
"# In this case you compare the \"queries\" in the target language with the \"keys\" in the source language, \n",
"# and after that you reweight the \"values\" in the source language. \n",
"\n",
"# Usually the number of \"queries\" is different from that of \"keys\" or \"values,\" because a sentence \n",
"# and its translation usually have different numbers of tokens. \n",
"\n",
"# Let's see an example where the input sentence has 9 tokens and the translated sentence has 12. \n",
"temp_mha = MultiHeadAttention(d_model=512, num_heads=8)\n",
"sample_sentence_source_lang = tf.random.uniform((1, 9, 512)) # (batch_size, encoder_sequence, d_model)\n",
"sample_sentence_target_lang = tf.random.uniform((1, 12, 512)) # (batch_size, decoder_sequence, d_model)\n",
"out, attn = temp_mha(v=sample_sentence_source_lang, k=sample_sentence_source_lang, q=sample_sentence_target_lang, mask=None)\n",
"\n",
"# In the results below, you can see that the \"values\" in the source sentence are reweighted with a (12, 9) sized matrix\n",
"# in each head, and the size of the resulting 'scaled_attention' is (12, 64) in each head. \n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.9"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
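The cell above calls a `MultiHeadAttention` class defined in an earlier part of this series, so it is not runnable from this gist alone. As a minimal sketch of the same shape bookkeeping, assuming standard scaled dot-product attention (hypothetical helper names, NumPy instead of TensorFlow, and no trained dense layers), the printed shapes can be reproduced like this:

```python
import numpy as np

def split_heads(x, num_heads):
    # (batch, seq_len, d_model) -> (batch, num_heads, seq_len, depth)
    batch, seq_len, d_model = x.shape
    depth = d_model // num_heads
    return x.reshape(batch, seq_len, num_heads, depth).transpose(0, 2, 1, 3)

def scaled_dot_product_attention(q, k, v):
    # q: (batch, heads, seq_q, depth); k, v: (batch, heads, seq_k, depth)
    depth = q.shape[-1]
    scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(depth)  # (batch, heads, seq_q, seq_k)
    # softmax over the key axis
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v, weights

d_model, num_heads = 512, 8
q = np.random.rand(1, 12, d_model)  # target-language "queries", 12 tokens
k = np.random.rand(1, 9, d_model)   # source-language "keys", 9 tokens
v = np.random.rand(1, 9, d_model)   # source-language "values", 9 tokens

qh, kh, vh = (split_heads(x, num_heads) for x in (q, k, v))
scaled_attention, attention_weights = scaled_dot_product_attention(qh, kh, vh)

# transpose back and concatenate the heads
concat_attention = scaled_attention.transpose(0, 2, 1, 3).reshape(1, 12, d_model)

print(attention_weights.shape)  # (1, 8, 12, 9)
print(scaled_attention.shape)   # (1, 8, 12, 64)
print(concat_attention.shape)   # (1, 12, 512)
```

Each head reweights the 9 source-language "values" with its own (12, 9) attention matrix, exactly as in the cell's output; the real class additionally passes `q`, `k`, `v`, and the concatenated result through dense layers.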