Skip to content

Instantly share code, notes, and snippets.

@YasuThompson
Created February 18, 2021 10:55
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save YasuThompson/143a6ec357792764e65e7bb80f393e69 to your computer and use it in GitHub Desktop.
Save YasuThompson/143a6ec357792764e65e7bb80f393e69 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"# The split_heads() function of the MultiHeadAttention class reshapes\n",
"# and transposes its input. This corresponds to splitting the input\n",
"# sentence into 8 heads.\n",
"# NOTE(review): 'tf' and 'sample_sentence' are not defined in this notebook\n",
"# (execution counts start at 9), so they presumably come from cells of an\n",
"# earlier session. TODO: add them so that Restart & Run All works.\n",
"# Reshape to (1, 9, 8, 64) = (batch, tokens, heads, values per head):\n",
"# each token's 8 * 64 = 512 values are cut into 8 consecutive groups of 64.\n",
"# perm=[0, 2, 1, 3] then swaps the token and head axes, giving\n",
"# (batch, heads, tokens, values per head) = (1, 8, 9, 64).\n",
"# Caution: this cell overwrites 'sample_sentence'. Re-running it without\n",
"# re-creating 'sample_sentence' raises no error (the element count still\n",
"# matches), but it silently scrambles the values.\n",
"sample_sentence = tf.reshape(sample_sentence, (1, 9, 8, 64))\n",
"sample_sentence = tf.transpose(sample_sentence, perm=[0, 2, 1, 3])"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(1, 8, 9, 64)\n"
]
}
],
"source": [
"# Shape of 'sample_sentence' after splitting into 8 heads, in the order\n",
"# (batch, heads, tokens, values per head). Each of the 8 (9, 64) matrices\n",
"# is one head, drawn in its own color in the article's figure.\n",
"print(sample_sentence.shape)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(9, 64)\n"
]
}
],
"source": [
"# sample_sentence[0, 0] selects batch item 0, head 0 -- the blue head in the\n",
"# article's figure. It is a (9, 64) matrix: 9 tokens x 64 values per head.\n",
"print(sample_sentence[0, 0].shape)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tf.Tensor(\n",
"[[ 1 2 3 4 5 6 7 8 9 10 11 12 13 14\n",
" 15 16 17 18 19 20 21 22 23 24 25 26 27 28\n",
" 29 30 31 32 33 34 35 36 37 38 39 40 41 42\n",
" 43 44 45 46 47 48 49 50 51 52 53 54 55 56\n",
" 57 58 59 60 61 62 63 64]\n",
" [ 513 514 515 516 517 518 519 520 521 522 523 524 525 526\n",
" 527 528 529 530 531 532 533 534 535 536 537 538 539 540\n",
" 541 542 543 544 545 546 547 548 549 550 551 552 553 554\n",
" 555 556 557 558 559 560 561 562 563 564 565 566 567 568\n",
" 569 570 571 572 573 574 575 576]\n",
" [1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038\n",
" 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052\n",
" 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066\n",
" 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 1080\n",
" 1081 1082 1083 1084 1085 1086 1087 1088]\n",
" [1537 1538 1539 1540 1541 1542 1543 1544 1545 1546 1547 1548 1549 1550\n",
" 1551 1552 1553 1554 1555 1556 1557 1558 1559 1560 1561 1562 1563 1564\n",
" 1565 1566 1567 1568 1569 1570 1571 1572 1573 1574 1575 1576 1577 1578\n",
" 1579 1580 1581 1582 1583 1584 1585 1586 1587 1588 1589 1590 1591 1592\n",
" 1593 1594 1595 1596 1597 1598 1599 1600]\n",
" [2049 2050 2051 2052 2053 2054 2055 2056 2057 2058 2059 2060 2061 2062\n",
" 2063 2064 2065 2066 2067 2068 2069 2070 2071 2072 2073 2074 2075 2076\n",
" 2077 2078 2079 2080 2081 2082 2083 2084 2085 2086 2087 2088 2089 2090\n",
" 2091 2092 2093 2094 2095 2096 2097 2098 2099 2100 2101 2102 2103 2104\n",
" 2105 2106 2107 2108 2109 2110 2111 2112]\n",
" [2561 2562 2563 2564 2565 2566 2567 2568 2569 2570 2571 2572 2573 2574\n",
" 2575 2576 2577 2578 2579 2580 2581 2582 2583 2584 2585 2586 2587 2588\n",
" 2589 2590 2591 2592 2593 2594 2595 2596 2597 2598 2599 2600 2601 2602\n",
" 2603 2604 2605 2606 2607 2608 2609 2610 2611 2612 2613 2614 2615 2616\n",
" 2617 2618 2619 2620 2621 2622 2623 2624]\n",
" [3073 3074 3075 3076 3077 3078 3079 3080 3081 3082 3083 3084 3085 3086\n",
" 3087 3088 3089 3090 3091 3092 3093 3094 3095 3096 3097 3098 3099 3100\n",
" 3101 3102 3103 3104 3105 3106 3107 3108 3109 3110 3111 3112 3113 3114\n",
" 3115 3116 3117 3118 3119 3120 3121 3122 3123 3124 3125 3126 3127 3128\n",
" 3129 3130 3131 3132 3133 3134 3135 3136]\n",
" [3585 3586 3587 3588 3589 3590 3591 3592 3593 3594 3595 3596 3597 3598\n",
" 3599 3600 3601 3602 3603 3604 3605 3606 3607 3608 3609 3610 3611 3612\n",
" 3613 3614 3615 3616 3617 3618 3619 3620 3621 3622 3623 3624 3625 3626\n",
" 3627 3628 3629 3630 3631 3632 3633 3634 3635 3636 3637 3638 3639 3640\n",
" 3641 3642 3643 3644 3645 3646 3647 3648]\n",
" [4097 4098 4099 4100 4101 4102 4103 4104 4105 4106 4107 4108 4109 4110\n",
" 4111 4112 4113 4114 4115 4116 4117 4118 4119 4120 4121 4122 4123 4124\n",
" 4125 4126 4127 4128 4129 4130 4131 4132 4133 4134 4135 4136 4137 4138\n",
" 4139 4140 4141 4142 4143 4144 4145 4146 4147 4148 4149 4150 4151 4152\n",
" 4153 4154 4155 4156 4157 4158 4159 4160]], shape=(9, 64), dtype=int64)\n"
]
}
],
"source": [
"# Head 0 of batch item 0: the blue matrix in the article's figure.\n",
"# Because of the reshape to (1, 9, 8, 64) and the transpose, head h holds\n",
"# values h*64 .. h*64+63 of every token, so each row printed below is the\n",
"# first 64 of the 512 values of one token in the tensor before reshaping\n",
"# and transposing.\n",
"# Note: the split_heads cell above overwrote 'sample_sentence', so that\n",
"# pre-reshape tensor is no longer available here for a side-by-side check.\n",
"print(sample_sentence[0, 0])"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.9"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment