{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Inside 'MultiHeadAttention' class...\n",
      "\n",
      "The shape of 'q' is (1, 12, 512)\n",
      "The shape of 'k' is (1, 9, 512)\n",
      "The shape of 'v' is (1, 9, 512)\n",
      "\n",
      "After passing 'q', 'k', 'v' through densely connected layers....\n",
      "The shape of 'q' is (1, 12, 512)\n",
      "The shape of 'k' is (1, 9, 512)\n",
      "The shape of 'v' is (1, 9, 512)\n",
      "\n",
      "After splitting the heads....\n",
      "The shape of 'q' is (1, 8, 12, 64)\n",
      "The shape of 'k' is (1, 8, 9, 64)\n",
      "The shape of 'v' is (1, 8, 9, 64)\n",
      "\n",
      "The shape of 'attention_weights' is (1, 8, 12, 9)\n",
      "The shape of 'scaled_attention' is (1, 8, 12, 64)\n",
      "\n",
      "After transposing....\n",
      "The shape of 'scaled_attention' is (1, 12, 8, 64)\n",
      "\n",
      "The shape of 'concat_attention' is (1, 12, 512)\n",
      "\n",
      "The shape of 'output' is (1, 12, 512)\n"
     ]
    }
   ],
   "source": [
"# As I mentioned, \"queries\" can be in different language from \"keys\" or \"values.\"\n", | |
"# * They are supposed to be different in translation tasks. \n", | |
"\n", | |
"# In this case you compare \"quries\" in the target language, with the \"keys\" in the original language. \n", | |
"# And after that you reweight \"values\" in the original language. \n", | |
"\n", | |
"# Usually, the numbef or \"queries\" is different from that of \"keys\" or \"values.\" because \n", | |
"# translated sentences usually have different number of tokens. \n", | |
"\n", | |
"# Let's see an example where the number of input sentence is 9 and that of the translated sentence is 12. \n", | |
"temp_mha = MultiHeadAttention(d_model=512, num_heads=8)\n", | |
"sample_sentence_source_lang = tf.random.uniform((1, 9, 512)) # (batch_size, encoder_sequence, d_model)\n", | |
"sample_sentence_target_lang = tf.random.uniform((1, 12, 512)) # (batch_size, encoder_sequence, d_model)\n", | |
"out, attn = temp_mha(v=sample_sentence_source_lang, k=sample_sentence_source_lang, q=sample_sentence_target_lang, mask=None)\n", | |
"\n", | |
"# In the results below, you can see that you reweight the \"values\" in the original sentence with a (12, 9) sized matrix\n", | |
"# in each head, and the the size of the resulting 'scaled_attention' is (12, 64) in each head. \n" | |
   ]
  },
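  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A minimal sketch of the scaled dot-product attention behind the shapes above. \n",
    "# It re-derives the per-head computation by hand, without the 'MultiHeadAttention' class, \n",
    "# on fresh random tensors, so only the shapes (not the values) are meaningful here. \n",
    "import tensorflow as tf\n",
    "\n",
    "d_k = 64 # depth per head (512 / 8)\n",
    "q = tf.random.uniform((1, 8, 12, d_k)) # (batch_size, num_heads, target_len, depth)\n",
    "k = tf.random.uniform((1, 8, 9, d_k)) # (batch_size, num_heads, source_len, depth)\n",
    "v = tf.random.uniform((1, 8, 9, d_k)) # (batch_size, num_heads, source_len, depth)\n",
    "\n",
    "# Compare every target \"query\" with every source \"key\": (12, 64) x (64, 9) -> (12, 9) per head. \n",
    "scores = tf.matmul(q, k, transpose_b=True) / tf.math.sqrt(tf.cast(d_k, tf.float32))\n",
    "weights = tf.nn.softmax(scores, axis=-1) # each target position gets 9 weights that sum to 1\n",
    "\n",
    "# Reweight the source \"values\" with the (12, 9) matrix: (12, 9) x (9, 64) -> (12, 64) per head. \n",
    "scaled_attention = tf.matmul(weights, v)\n",
    "\n",
    "print(weights.shape) # (1, 8, 12, 9)\n",
    "print(scaled_attention.shape) # (1, 8, 12, 64)\n"
   ]
  }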
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}