-
-
Save YasuThompson/143a6ec357792764e65e7bb80f393e69 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# The split_heads() function in the MultiHeadAttention() class\n", | |
"# reshapes and transposes its input; this cell reproduces that\n", | |
"# step, splitting the sample sentence into 8 heads.\n", | |
"# NOTE(review): relies on `tf` being imported and `sample_sentence`\n", | |
"# being defined in earlier cells that are not part of this notebook.\n", | |
"#\n", | |
"# The result goes into a new variable so this cell can be re-run\n", | |
"# safely (reshaping the already-split tensor a second time would\n", | |
"# silently scramble it) and so the unsplit 'sample_sentence' is\n", | |
"# still available for the comparison at the end.\n", | |
"num_heads = 8  # number of attention heads\n", | |
"depth = 64     # values per head for each position (8 * 64 = 512)\n", | |
"split_sentence = tf.reshape(sample_sentence, (1, -1, num_heads, depth))\n", | |
"split_sentence = tf.transpose(split_sentence, perm=[0, 2, 1, 3])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"(1, 8, 9, 64)\n" | |
] | |
} | |
], | |
"source": [ | |
"# Shape after splitting into 8 heads: (1, num_heads, 9, depth).\n", | |
"# Each (9, 64) slice along axis 1 is one head, drawn in its own\n", | |
"# color in the article's figure.\n", | |
"print(split_sentence.shape)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"(9, 64)\n" | |
] | |
} | |
], | |
"source": [ | |
"# Head 0, the blue head in the article's figure, is a (9, 64) matrix.\n", | |
"print(split_sentence[0, 0].shape)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"metadata": { | |
"scrolled": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"tf.Tensor(\n", | |
"[[ 1 2 3 4 5 6 7 8 9 10 11 12 13 14\n", | |
" 15 16 17 18 19 20 21 22 23 24 25 26 27 28\n", | |
" 29 30 31 32 33 34 35 36 37 38 39 40 41 42\n", | |
" 43 44 45 46 47 48 49 50 51 52 53 54 55 56\n", | |
" 57 58 59 60 61 62 63 64]\n", | |
" [ 513 514 515 516 517 518 519 520 521 522 523 524 525 526\n", | |
" 527 528 529 530 531 532 533 534 535 536 537 538 539 540\n", | |
" 541 542 543 544 545 546 547 548 549 550 551 552 553 554\n", | |
" 555 556 557 558 559 560 561 562 563 564 565 566 567 568\n", | |
" 569 570 571 572 573 574 575 576]\n", | |
" [1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038\n", | |
" 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052\n", | |
" 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066\n", | |
" 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 1080\n", | |
" 1081 1082 1083 1084 1085 1086 1087 1088]\n", | |
" [1537 1538 1539 1540 1541 1542 1543 1544 1545 1546 1547 1548 1549 1550\n", | |
" 1551 1552 1553 1554 1555 1556 1557 1558 1559 1560 1561 1562 1563 1564\n", | |
" 1565 1566 1567 1568 1569 1570 1571 1572 1573 1574 1575 1576 1577 1578\n", | |
" 1579 1580 1581 1582 1583 1584 1585 1586 1587 1588 1589 1590 1591 1592\n", | |
" 1593 1594 1595 1596 1597 1598 1599 1600]\n", | |
" [2049 2050 2051 2052 2053 2054 2055 2056 2057 2058 2059 2060 2061 2062\n", | |
" 2063 2064 2065 2066 2067 2068 2069 2070 2071 2072 2073 2074 2075 2076\n", | |
" 2077 2078 2079 2080 2081 2082 2083 2084 2085 2086 2087 2088 2089 2090\n", | |
" 2091 2092 2093 2094 2095 2096 2097 2098 2099 2100 2101 2102 2103 2104\n", | |
" 2105 2106 2107 2108 2109 2110 2111 2112]\n", | |
" [2561 2562 2563 2564 2565 2566 2567 2568 2569 2570 2571 2572 2573 2574\n", | |
" 2575 2576 2577 2578 2579 2580 2581 2582 2583 2584 2585 2586 2587 2588\n", | |
" 2589 2590 2591 2592 2593 2594 2595 2596 2597 2598 2599 2600 2601 2602\n", | |
" 2603 2604 2605 2606 2607 2608 2609 2610 2611 2612 2613 2614 2615 2616\n", | |
" 2617 2618 2619 2620 2621 2622 2623 2624]\n", | |
" [3073 3074 3075 3076 3077 3078 3079 3080 3081 3082 3083 3084 3085 3086\n", | |
" 3087 3088 3089 3090 3091 3092 3093 3094 3095 3096 3097 3098 3099 3100\n", | |
" 3101 3102 3103 3104 3105 3106 3107 3108 3109 3110 3111 3112 3113 3114\n", | |
" 3115 3116 3117 3118 3119 3120 3121 3122 3123 3124 3125 3126 3127 3128\n", | |
" 3129 3130 3131 3132 3133 3134 3135 3136]\n", | |
" [3585 3586 3587 3588 3589 3590 3591 3592 3593 3594 3595 3596 3597 3598\n", | |
" 3599 3600 3601 3602 3603 3604 3605 3606 3607 3608 3609 3610 3611 3612\n", | |
" 3613 3614 3615 3616 3617 3618 3619 3620 3621 3622 3623 3624 3625 3626\n", | |
" 3627 3628 3629 3630 3631 3632 3633 3634 3635 3636 3637 3638 3639 3640\n", | |
" 3641 3642 3643 3644 3645 3646 3647 3648]\n", | |
" [4097 4098 4099 4100 4101 4102 4103 4104 4105 4106 4107 4108 4109 4110\n", | |
" 4111 4112 4113 4114 4115 4116 4117 4118 4119 4120 4121 4122 4123 4124\n", | |
" 4125 4126 4127 4128 4129 4130 4131 4132 4133 4134 4135 4136 4137 4138\n", | |
" 4139 4140 4141 4142 4143 4144 4145 4146 4147 4148 4149 4150 4151 4152\n", | |
" 4153 4154 4155 4156 4157 4158 4159 4160]], shape=(9, 64), dtype=int64)\n" | |
] | |
} | |
], | |
"source": [ | |
"# Compare this with the unsplit 'sample_sentence' (left unchanged by\n", | |
"# the split cell above): row i here is the first 64 of the 512 values\n", | |
"# at position i, i.e. the part of each position routed to head 0.\n", | |
"# This is the blue matrix in the figure in the article.\n", | |
"print(split_sentence[0, 0])" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.7.9" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 4 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment