@YasuThompson
Created February 18, 2021 10:55
{
"cells": [
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Inside 'MultiHeadAttention' class...\n",
"\n",
"The shape of 'q' is (1, 9, 512)\n",
"The shape of 'k' is (1, 9, 512)\n",
"The shape of 'v' is (1, 9, 512)\n",
"\n",
"After passing 'q', 'k', 'v' through densely connected layers....\n",
"The shape of 'q' is (1, 9, 512)\n",
"The shape of 'k' is (1, 9, 512)\n",
"The shape of 'v' is (1, 9, 512)\n",
"\n",
"After splitting the heads....\n",
"The shape of 'q' is (1, 8, 9, 64)\n",
"The shape of 'k' is (1, 8, 9, 64)\n",
"The shape of 'v' is (1, 8, 9, 64)\n",
"\n",
"The shape of 'attention_weights' is (1, 8, 9, 9)\n",
"The shape of 'scaled_attention' is (1, 8, 9, 64)\n",
"\n",
"After transposing....\n",
"The shape of 'scaled_attention' is (1, 9, 8, 64)\n",
"\n",
"The shape of 'concat_attention' is (1, 9, 512)\n",
"\n",
"The shape of 'output' is (1, 9, 512)\n"
]
}
],
"source": [
"temp_mha = MultiHeadAttention(d_model=512, num_heads=8)\n",
"sample_sentence = tf.random.uniform((1, 9, 512)) # (batch_size, encoder_sequence, d_model)\n",
"out, attn = temp_mha(v=sample_sentence, k=sample_sentence, q=sample_sentence, mask=None)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(TensorShape([1, 9, 512]), TensorShape([1, 8, 9, 9]))"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"out.shape, attn.shape"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.9"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
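The printed trace above follows the usual shape bookkeeping of multi-head attention. Because the notebook calls the layer with `v=k=q=sample_sentence`, the query and key sequence lengths are both 9. The `MultiHeadAttention` class itself is not included in this gist, so the sketch below does not reproduce it. It only recomputes the shapes with plain arithmetic, assuming three things:

- the per-head size is `depth = d_model // num_heads` (512 // 8 = 64 in the trace);
- heads are split to `(batch, num_heads, seq, depth)`;
- heads are merged back after a transpose.

The helper name `mha_shape_trace` and its stage labels are hypothetical and exist only for illustration.

```python
def mha_shape_trace(batch, seq_q, seq_k, d_model, num_heads):
    """Return the tensor shape at each stage of multi-head attention.

    Hypothetical helper: shape arithmetic only (no tensors), assuming the
    split -> attend -> transpose -> concat layout shown in the trace above.
    """
    if d_model % num_heads != 0:
        raise ValueError("d_model must be divisible by num_heads")
    depth = d_model // num_heads  # per-head feature size, e.g. 512 // 8 = 64
    return {
        # Dense projections keep the last dimension at d_model.
        "q_dense": (batch, seq_q, d_model),
        "k_dense": (batch, seq_k, d_model),
        "v_dense": (batch, seq_k, d_model),
        # Split heads: (batch, seq, d_model) -> (batch, num_heads, seq, depth).
        "q_split": (batch, num_heads, seq_q, depth),
        "k_split": (batch, num_heads, seq_k, depth),
        "v_split": (batch, num_heads, seq_k, depth),
        # q @ k^T contracts the depth axis: one weight per (query, key) pair.
        "attention_weights": (batch, num_heads, seq_q, seq_k),
        # weights @ v: (..., seq_q, seq_k) x (..., seq_k, depth) -> (..., seq_q, depth).
        "scaled_attention": (batch, num_heads, seq_q, depth),
        # Swap the head and sequence axes so heads sit next to depth.
        "transposed": (batch, seq_q, num_heads, depth),
        # Merge heads: num_heads * depth == d_model.
        "concat_attention": (batch, seq_q, num_heads * depth),
        # Final dense layer projects back to d_model.
        "output": (batch, seq_q, d_model),
    }


if __name__ == "__main__":
    # Settings from the notebook cell: batch 1, 9 tokens, d_model=512, 8 heads.
    for stage, shape in mha_shape_trace(1, 9, 9, 512, 8).items():
        print(f"{stage}: {shape}")
```

With the notebook's settings, `attention_weights` comes out as `(1, 8, 9, 9)` and `output` as `(1, 9, 512)`. These match the `attn.shape` and `out.shape` values returned in the second cell. The concat step recovers `d_model` only because `num_heads * depth == d_model`, which is why the divisibility check comes first.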