Skip to content

Instantly share code, notes, and snippets.

@YasuThompson
Created February 18, 2021 10:56
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save YasuThompson/300da4f7b5c11ccc7a5cc6a73ff4135d to your computer and use it in GitHub Desktop.
Save YasuThompson/300da4f7b5c11ccc7a5cc6a73ff4135d to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tf.Tensor(\n",
"[[[[ 89440 1154400 2219360 3284320 4349280 5414240\n",
" 6479200 7544160 8609120]\n",
" [ 1154400 18996576 36838752 54680928 72523104 90365280\n",
" 108207456 126049632 143891808]\n",
" [ 2219360 36838752 71458144 106077536 140696928 175316320\n",
" 209935712 244555104 279174496]\n",
" [ 3284320 54680928 106077536 157474144 208870752 260267360\n",
" 311663968 363060576 414457184]\n",
" [ 4349280 72523104 140696928 208870752 277044576 345218400\n",
" 413392224 481566048 549739872]\n",
" [ 5414240 90365280 175316320 260267360 345218400 430169440\n",
" 515120480 600071520 685022560]\n",
" [ 6479200 108207456 209935712 311663968 413392224 515120480\n",
" 616848736 718576992 820305248]\n",
" [ 7544160 126049632 244555104 363060576 481566048 600071520\n",
" 718576992 837082464 955587936]\n",
" [ 8609120 143891808 279174496 414457184 549739872 685022560\n",
" 820305248 955587936 1090870624]]\n",
"\n",
" [[ 617824 3779936 6942048 10104160 13266272 16428384\n",
" 19590496 22752608 25914720]\n",
" [ 3779936 23719264 43658592 63597920 83537248 103476576\n",
" 123415904 143355232 163294560]\n",
" [ 6942048 43658592 80375136 117091680 153808224 190524768\n",
" 227241312 263957856 300674400]\n",
" [ 10104160 63597920 117091680 170585440 224079200 277572960\n",
" 331066720 384560480 438054240]\n",
" [ 13266272 83537248 153808224 224079200 294350176 364621152\n",
" 434892128 505163104 575434080]\n",
" [ 16428384 103476576 190524768 277572960 364621152 451669344\n",
" 538717536 625765728 712813920]\n",
" [ 19590496 123415904 227241312 331066720 434892128 538717536\n",
" 642542944 746368352 850193760]\n",
" [ 22752608 143355232 263957856 384560480 505163104 625765728\n",
" 746368352 866970976 987573600]\n",
" [ 25914720 163294560 300674400 438054240 575434080 712813920\n",
" 850193760 987573600 1124953440]]\n",
"\n",
" [[ 1670496 6929760 12189024 17448288 22707552 27966816\n",
" 33226080 38485344 43744608]\n",
" [ 6929760 28966240 51002720 73039200 95075680 117112160\n",
" 139148640 161185120 183221600]\n",
" [ 12189024 51002720 89816416 128630112 167443808 206257504\n",
" 245071200 283884896 322698592]\n",
" [ 17448288 73039200 128630112 184221024 239811936 295402848\n",
" 350993760 406584672 462175584]\n",
" [ 22707552 95075680 167443808 239811936 312180064 384548192\n",
" 456916320 529284448 601652576]\n",
" [ 27966816 117112160 206257504 295402848 384548192 473693536\n",
" 562838880 651984224 741129568]\n",
" [ 33226080 139148640 245071200 350993760 456916320 562838880\n",
" 668761440 774684000 880606560]\n",
" [ 38485344 161185120 283884896 406584672 529284448 651984224\n",
" 774684000 897383776 1020083552]\n",
" [ 43744608 183221600 322698592 462175584 601652576 741129568\n",
" 880606560 1020083552 1159560544]]\n",
"\n",
" [[ 3247456 10603872 17960288 25316704 32673120 40029536\n",
" 47385952 54742368 62098784]\n",
" [ 10603872 34737504 58871136 83004768 107138400 131272032\n",
" 155405664 179539296 203672928]\n",
" [ 17960288 58871136 99781984 140692832 181603680 222514528\n",
" 263425376 304336224 345247072]\n",
" [ 25316704 83004768 140692832 198380896 256068960 313757024\n",
" 371445088 429133152 486821216]\n",
" [ 32673120 107138400 181603680 256068960 330534240 404999520\n",
" 479464800 553930080 628395360]\n",
" [ 40029536 131272032 222514528 313757024 404999520 496242016\n",
" 587484512 678727008 769969504]\n",
" [ 47385952 155405664 263425376 371445088 479464800 587484512\n",
" 695504224 803523936 911543648]\n",
" [ 54742368 179539296 304336224 429133152 553930080 678727008\n",
" 803523936 928320864 1053117792]\n",
" [ 62098784 203672928 345247072 486821216 628395360 769969504\n",
" 911543648 1053117792 1194691936]]\n",
"\n",
" [[ 5348704 14802272 24255840 33709408 43162976 52616544\n",
" 62070112 71523680 80977248]\n",
" [ 14802272 41033056 67263840 93494624 119725408 145956192\n",
" 172186976 198417760 224648544]\n",
" [ 24255840 67263840 110271840 153279840 196287840 239295840\n",
" 282303840 325311840 368319840]\n",
" [ 33709408 93494624 153279840 213065056 272850272 332635488\n",
" 392420704 452205920 511991136]\n",
" [ 43162976 119725408 196287840 272850272 349412704 425975136\n",
" 502537568 579100000 655662432]\n",
" [ 52616544 145956192 239295840 332635488 425975136 519314784\n",
" 612654432 705994080 799333728]\n",
" [ 62070112 172186976 282303840 392420704 502537568 612654432\n",
" 722771296 832888160 943005024]\n",
" [ 71523680 198417760 325311840 452205920 579100000 705994080\n",
" 832888160 959782240 1086676320]\n",
" [ 80977248 224648544 368319840 511991136 655662432 799333728\n",
" 943005024 1086676320 1230347616]]\n",
"\n",
" [[ 7974240 19524960 31075680 42626400 54177120 65727840\n",
" 77278560 88829280 100380000]\n",
" [ 19524960 47852896 76180832 104508768 132836704 161164640\n",
" 189492576 217820512 246148448]\n",
" [ 31075680 76180832 121285984 166391136 211496288 256601440\n",
" 301706592 346811744 391916896]\n",
" [ 42626400 104508768 166391136 228273504 290155872 352038240\n",
" 413920608 475802976 537685344]\n",
" [ 54177120 132836704 211496288 290155872 368815456 447475040\n",
" 526134624 604794208 683453792]\n",
" [ 65727840 161164640 256601440 352038240 447475040 542911840\n",
" 638348640 733785440 829222240]\n",
" [ 77278560 189492576 301706592 413920608 526134624 638348640\n",
" 750562656 862776672 974990688]\n",
" [ 88829280 217820512 346811744 475802976 604794208 733785440\n",
" 862776672 991767904 1120759136]\n",
" [ 100380000 246148448 391916896 537685344 683453792 829222240\n",
" 974990688 1120759136 1266527584]]\n",
"\n",
" [[ 11124064 24771936 38419808 52067680 65715552 79363424\n",
" 93011296 106659168 120307040]\n",
" [ 24771936 55197024 85622112 116047200 146472288 176897376\n",
" 207322464 237747552 268172640]\n",
" [ 38419808 85622112 132824416 180026720 227229024 274431328\n",
" 321633632 368835936 416038240]\n",
" [ 52067680 116047200 180026720 244006240 307985760 371965280\n",
" 435944800 499924320 563903840]\n",
" [ 65715552 146472288 227229024 307985760 388742496 469499232\n",
" 550255968 631012704 711769440]\n",
" [ 79363424 176897376 274431328 371965280 469499232 567033184\n",
" 664567136 762101088 859635040]\n",
" [ 93011296 207322464 321633632 435944800 550255968 664567136\n",
" 778878304 893189472 1007500640]\n",
" [ 106659168 237747552 368835936 499924320 631012704 762101088\n",
" 893189472 1024277856 1155366240]\n",
" [ 120307040 268172640 416038240 563903840 711769440 859635040\n",
" 1007500640 1155366240 1303231840]]\n",
"\n",
" [[ 14798176 30543200 46288224 62033248 77778272 93523296\n",
" 109268320 125013344 140758368]\n",
" [ 30543200 63065440 95587680 128109920 160632160 193154400\n",
" 225676640 258198880 290721120]\n",
" [ 46288224 95587680 144887136 194186592 243486048 292785504\n",
" 342084960 391384416 440683872]\n",
" [ 62033248 128109920 194186592 260263264 326339936 392416608\n",
" 458493280 524569952 590646624]\n",
" [ 77778272 160632160 243486048 326339936 409193824 492047712\n",
" 574901600 657755488 740609376]\n",
" [ 93523296 193154400 292785504 392416608 492047712 591678816\n",
" 691309920 790941024 890572128]\n",
" [ 109268320 225676640 342084960 458493280 574901600 691309920\n",
" 807718240 924126560 1040534880]\n",
" [ 125013344 258198880 391384416 524569952 657755488 790941024\n",
" 924126560 1057312096 1190497632]\n",
" [ 140758368 290721120 440683872 590646624 740609376 890572128\n",
" 1040534880 1190497632 1340460384]]]], shape=(1, 8, 9, 9), dtype=int64)\n"
]
}
],
"source": [
"# tf.matmul() multiplies matrices stored in the last two axes of its\n",
"# inputs and treats every leading axis as a batch axis, so one call\n",
"# handles all heads at once, each head independently.\n",
"# With transpose_b=True the second operand is transposed over its last\n",
"# two axes, so the output below is the raw Q K^T of every head: the\n",
"# 1/sqrt(d_k) scaling and the softmax of scaled dot-product attention\n",
"# are NOT applied yet. (Here Q and K are both sample_sentence.)\n",
"print(tf.matmul(\n",
"    sample_sentence,\n",
"    sample_sentence,\n",
"    transpose_b=True,  # contract the last axis of both operands\n",
"))"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"text/plain": [
"<tf.Tensor: shape=(9, 9), dtype=int64, numpy=\n",
"array([[ 89440, 1154400, 2219360, 3284320, 4349280,\n",
" 5414240, 6479200, 7544160, 8609120],\n",
" [ 1154400, 18996576, 36838752, 54680928, 72523104,\n",
" 90365280, 108207456, 126049632, 143891808],\n",
" [ 2219360, 36838752, 71458144, 106077536, 140696928,\n",
" 175316320, 209935712, 244555104, 279174496],\n",
" [ 3284320, 54680928, 106077536, 157474144, 208870752,\n",
" 260267360, 311663968, 363060576, 414457184],\n",
" [ 4349280, 72523104, 140696928, 208870752, 277044576,\n",
" 345218400, 413392224, 481566048, 549739872],\n",
" [ 5414240, 90365280, 175316320, 260267360, 345218400,\n",
" 430169440, 515120480, 600071520, 685022560],\n",
" [ 6479200, 108207456, 209935712, 311663968, 413392224,\n",
" 515120480, 616848736, 718576992, 820305248],\n",
" [ 7544160, 126049632, 244555104, 363060576, 481566048,\n",
" 600071520, 718576992, 837082464, 955587936],\n",
" [ 8609120, 143891808, 279174496, 414457184, 549739872,\n",
" 685022560, 820305248, 955587936, 1090870624]])>"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Take the single (batch 0, head 0) matrix, of shape (9, depth), and\n",
"# multiply it by its own transpose. In the accompanying figure (not part\n",
"# of this notebook) these are the two blue matrices. This is the same\n",
"# computation the batched call performs for that one head, so the result\n",
"# must equal slice [0, 0] of\n",
"# tf.matmul(sample_sentence, sample_sentence, transpose_b=True).\n",
"# A single [0, 0] index replaces the repeated [0][0] chain, and\n",
"# transpose_b=True lets the MatMul kernel handle the transpose instead of\n",
"# building a permuted copy with tf.transpose.\n",
"head_0 = sample_sentence[0, 0]\n",
"head_0_logits = tf.matmul(head_0, head_0, transpose_b=True)\n",
"\n",
"# Verify the claim instead of only stating it (raises if they differ).\n",
"tf.debugging.assert_equal(\n",
"    head_0_logits,\n",
"    tf.matmul(sample_sentence, sample_sentence, transpose_b=True)[0, 0],\n",
")\n",
"head_0_logits"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.9"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment