-
-
Save YasuThompson/300da4f7b5c11ccc7a5cc6a73ff4135d to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"tf.Tensor(\n", | |
"[[[[ 89440 1154400 2219360 3284320 4349280 5414240\n", | |
" 6479200 7544160 8609120]\n", | |
" [ 1154400 18996576 36838752 54680928 72523104 90365280\n", | |
" 108207456 126049632 143891808]\n", | |
" [ 2219360 36838752 71458144 106077536 140696928 175316320\n", | |
" 209935712 244555104 279174496]\n", | |
" [ 3284320 54680928 106077536 157474144 208870752 260267360\n", | |
" 311663968 363060576 414457184]\n", | |
" [ 4349280 72523104 140696928 208870752 277044576 345218400\n", | |
" 413392224 481566048 549739872]\n", | |
" [ 5414240 90365280 175316320 260267360 345218400 430169440\n", | |
" 515120480 600071520 685022560]\n", | |
" [ 6479200 108207456 209935712 311663968 413392224 515120480\n", | |
" 616848736 718576992 820305248]\n", | |
" [ 7544160 126049632 244555104 363060576 481566048 600071520\n", | |
" 718576992 837082464 955587936]\n", | |
" [ 8609120 143891808 279174496 414457184 549739872 685022560\n", | |
" 820305248 955587936 1090870624]]\n", | |
"\n", | |
" [[ 617824 3779936 6942048 10104160 13266272 16428384\n", | |
" 19590496 22752608 25914720]\n", | |
" [ 3779936 23719264 43658592 63597920 83537248 103476576\n", | |
" 123415904 143355232 163294560]\n", | |
" [ 6942048 43658592 80375136 117091680 153808224 190524768\n", | |
" 227241312 263957856 300674400]\n", | |
" [ 10104160 63597920 117091680 170585440 224079200 277572960\n", | |
" 331066720 384560480 438054240]\n", | |
" [ 13266272 83537248 153808224 224079200 294350176 364621152\n", | |
" 434892128 505163104 575434080]\n", | |
" [ 16428384 103476576 190524768 277572960 364621152 451669344\n", | |
" 538717536 625765728 712813920]\n", | |
" [ 19590496 123415904 227241312 331066720 434892128 538717536\n", | |
" 642542944 746368352 850193760]\n", | |
" [ 22752608 143355232 263957856 384560480 505163104 625765728\n", | |
" 746368352 866970976 987573600]\n", | |
" [ 25914720 163294560 300674400 438054240 575434080 712813920\n", | |
" 850193760 987573600 1124953440]]\n", | |
"\n", | |
" [[ 1670496 6929760 12189024 17448288 22707552 27966816\n", | |
" 33226080 38485344 43744608]\n", | |
" [ 6929760 28966240 51002720 73039200 95075680 117112160\n", | |
" 139148640 161185120 183221600]\n", | |
" [ 12189024 51002720 89816416 128630112 167443808 206257504\n", | |
" 245071200 283884896 322698592]\n", | |
" [ 17448288 73039200 128630112 184221024 239811936 295402848\n", | |
" 350993760 406584672 462175584]\n", | |
" [ 22707552 95075680 167443808 239811936 312180064 384548192\n", | |
" 456916320 529284448 601652576]\n", | |
" [ 27966816 117112160 206257504 295402848 384548192 473693536\n", | |
" 562838880 651984224 741129568]\n", | |
" [ 33226080 139148640 245071200 350993760 456916320 562838880\n", | |
" 668761440 774684000 880606560]\n", | |
" [ 38485344 161185120 283884896 406584672 529284448 651984224\n", | |
" 774684000 897383776 1020083552]\n", | |
" [ 43744608 183221600 322698592 462175584 601652576 741129568\n", | |
" 880606560 1020083552 1159560544]]\n", | |
"\n", | |
" [[ 3247456 10603872 17960288 25316704 32673120 40029536\n", | |
" 47385952 54742368 62098784]\n", | |
" [ 10603872 34737504 58871136 83004768 107138400 131272032\n", | |
" 155405664 179539296 203672928]\n", | |
" [ 17960288 58871136 99781984 140692832 181603680 222514528\n", | |
" 263425376 304336224 345247072]\n", | |
" [ 25316704 83004768 140692832 198380896 256068960 313757024\n", | |
" 371445088 429133152 486821216]\n", | |
" [ 32673120 107138400 181603680 256068960 330534240 404999520\n", | |
" 479464800 553930080 628395360]\n", | |
" [ 40029536 131272032 222514528 313757024 404999520 496242016\n", | |
" 587484512 678727008 769969504]\n", | |
" [ 47385952 155405664 263425376 371445088 479464800 587484512\n", | |
" 695504224 803523936 911543648]\n", | |
" [ 54742368 179539296 304336224 429133152 553930080 678727008\n", | |
" 803523936 928320864 1053117792]\n", | |
" [ 62098784 203672928 345247072 486821216 628395360 769969504\n", | |
" 911543648 1053117792 1194691936]]\n", | |
"\n", | |
" [[ 5348704 14802272 24255840 33709408 43162976 52616544\n", | |
" 62070112 71523680 80977248]\n", | |
" [ 14802272 41033056 67263840 93494624 119725408 145956192\n", | |
" 172186976 198417760 224648544]\n", | |
" [ 24255840 67263840 110271840 153279840 196287840 239295840\n", | |
" 282303840 325311840 368319840]\n", | |
" [ 33709408 93494624 153279840 213065056 272850272 332635488\n", | |
" 392420704 452205920 511991136]\n", | |
" [ 43162976 119725408 196287840 272850272 349412704 425975136\n", | |
" 502537568 579100000 655662432]\n", | |
" [ 52616544 145956192 239295840 332635488 425975136 519314784\n", | |
" 612654432 705994080 799333728]\n", | |
" [ 62070112 172186976 282303840 392420704 502537568 612654432\n", | |
" 722771296 832888160 943005024]\n", | |
" [ 71523680 198417760 325311840 452205920 579100000 705994080\n", | |
" 832888160 959782240 1086676320]\n", | |
" [ 80977248 224648544 368319840 511991136 655662432 799333728\n", | |
" 943005024 1086676320 1230347616]]\n", | |
"\n", | |
" [[ 7974240 19524960 31075680 42626400 54177120 65727840\n", | |
" 77278560 88829280 100380000]\n", | |
" [ 19524960 47852896 76180832 104508768 132836704 161164640\n", | |
" 189492576 217820512 246148448]\n", | |
" [ 31075680 76180832 121285984 166391136 211496288 256601440\n", | |
" 301706592 346811744 391916896]\n", | |
" [ 42626400 104508768 166391136 228273504 290155872 352038240\n", | |
" 413920608 475802976 537685344]\n", | |
" [ 54177120 132836704 211496288 290155872 368815456 447475040\n", | |
" 526134624 604794208 683453792]\n", | |
" [ 65727840 161164640 256601440 352038240 447475040 542911840\n", | |
" 638348640 733785440 829222240]\n", | |
" [ 77278560 189492576 301706592 413920608 526134624 638348640\n", | |
" 750562656 862776672 974990688]\n", | |
" [ 88829280 217820512 346811744 475802976 604794208 733785440\n", | |
" 862776672 991767904 1120759136]\n", | |
" [ 100380000 246148448 391916896 537685344 683453792 829222240\n", | |
" 974990688 1120759136 1266527584]]\n", | |
"\n", | |
" [[ 11124064 24771936 38419808 52067680 65715552 79363424\n", | |
" 93011296 106659168 120307040]\n", | |
" [ 24771936 55197024 85622112 116047200 146472288 176897376\n", | |
" 207322464 237747552 268172640]\n", | |
" [ 38419808 85622112 132824416 180026720 227229024 274431328\n", | |
" 321633632 368835936 416038240]\n", | |
" [ 52067680 116047200 180026720 244006240 307985760 371965280\n", | |
" 435944800 499924320 563903840]\n", | |
" [ 65715552 146472288 227229024 307985760 388742496 469499232\n", | |
" 550255968 631012704 711769440]\n", | |
" [ 79363424 176897376 274431328 371965280 469499232 567033184\n", | |
" 664567136 762101088 859635040]\n", | |
" [ 93011296 207322464 321633632 435944800 550255968 664567136\n", | |
" 778878304 893189472 1007500640]\n", | |
" [ 106659168 237747552 368835936 499924320 631012704 762101088\n", | |
" 893189472 1024277856 1155366240]\n", | |
" [ 120307040 268172640 416038240 563903840 711769440 859635040\n", | |
" 1007500640 1155366240 1303231840]]\n", | |
"\n", | |
" [[ 14798176 30543200 46288224 62033248 77778272 93523296\n", | |
" 109268320 125013344 140758368]\n", | |
" [ 30543200 63065440 95587680 128109920 160632160 193154400\n", | |
" 225676640 258198880 290721120]\n", | |
" [ 46288224 95587680 144887136 194186592 243486048 292785504\n", | |
" 342084960 391384416 440683872]\n", | |
" [ 62033248 128109920 194186592 260263264 326339936 392416608\n", | |
" 458493280 524569952 590646624]\n", | |
" [ 77778272 160632160 243486048 326339936 409193824 492047712\n", | |
" 574901600 657755488 740609376]\n", | |
" [ 93523296 193154400 292785504 392416608 492047712 591678816\n", | |
" 691309920 790941024 890572128]\n", | |
" [ 109268320 225676640 342084960 458493280 574901600 691309920\n", | |
" 807718240 924126560 1040534880]\n", | |
" [ 125013344 258198880 391384416 524569952 657755488 790941024\n", | |
" 924126560 1057312096 1190497632]\n", | |
" [ 140758368 290721120 440683872 590646624 740609376 890572128\n", | |
" 1040534880 1190497632 1340460384]]]], shape=(1, 8, 9, 9), dtype=int64)\n" | |
] | |
} | |
], | |
"source": [ | |
"# tf.matmul() treats the last two axes of each input as matrices \n", | |
"# and batches over the leading axes, so the scaled dot-product \n", | |
"# can be calculated independently in each head.\n", | |
"# In particular, the calculation below corresponds to calculating \n", | |
"# QK^T (here Q = K = sample_sentence), without the 1/sqrt(d_k) \n", | |
"# rescaling or the softmax function.\n", | |
"# tf.linalg.matmul is the same op as tf.matmul; a and b are its two operands.\n", | |
"print(tf.linalg.matmul(a=sample_sentence, b=sample_sentence, transpose_b=True))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"metadata": { | |
"scrolled": true | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"<tf.Tensor: shape=(9, 9), dtype=int64, numpy=\n", | |
"array([[ 89440, 1154400, 2219360, 3284320, 4349280,\n", | |
" 5414240, 6479200, 7544160, 8609120],\n", | |
" [ 1154400, 18996576, 36838752, 54680928, 72523104,\n", | |
" 90365280, 108207456, 126049632, 143891808],\n", | |
" [ 2219360, 36838752, 71458144, 106077536, 140696928,\n", | |
" 175316320, 209935712, 244555104, 279174496],\n", | |
" [ 3284320, 54680928, 106077536, 157474144, 208870752,\n", | |
" 260267360, 311663968, 363060576, 414457184],\n", | |
" [ 4349280, 72523104, 140696928, 208870752, 277044576,\n", | |
" 345218400, 413392224, 481566048, 549739872],\n", | |
" [ 5414240, 90365280, 175316320, 260267360, 345218400,\n", | |
" 430169440, 515120480, 600071520, 685022560],\n", | |
" [ 6479200, 108207456, 209935712, 311663968, 413392224,\n", | |
" 515120480, 616848736, 718576992, 820305248],\n", | |
" [ 7544160, 126049632, 244555104, 363060576, 481566048,\n", | |
" 600071520, 718576992, 837082464, 955587936],\n", | |
" [ 8609120, 143891808, 279174496, 414457184, 549739872,\n", | |
" 685022560, 820305248, 955587936, 1090870624]])>" | |
] | |
}, | |
"execution_count": 14, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# The calculation below corresponds to multiplying the two blue matrices \n", | |
"# in the accompanying figure (not included in this notebook), and the \n", | |
"# result is the same as the [0][0] slice (the first head) of \n", | |
"# tf.matmul(sample_sentence, sample_sentence, transpose_b=True)\n", | |
"first_head_scores = tf.matmul(sample_sentence[0][0], tf.transpose(sample_sentence[0][0]))\n", | |
"# Verify the equivalence stated above in code rather than by eyeballing two printouts.\n", | |
"tf.debugging.assert_equal(\n", | |
"    first_head_scores,\n", | |
"    tf.matmul(sample_sentence, sample_sentence, transpose_b=True)[0][0])\n", | |
"first_head_scores" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.7.9" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 4 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment