Last active
April 21, 2019 11:08
-
-
Save hsm207/a16000ddd75097980275d401c3d27e73 to your computer and use it in GitHub Desktop.
Code to accompany my blog post at https://bit.ly/2PmRjiC
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"colab_type": "text", | |
"id": "dehkToEyBgaE", | |
"toc": true | |
}, | |
"source": [ | |
"<h1>Table of Contents<span class=\"tocSkip\"></span></h1>\n", | |
"<div class=\"toc\" style=\"margin-top: 1em;\"><ul class=\"toc-item\"><li><span><a href=\"#Libraries\" data-toc-modified-id=\"Libraries-1\"><span class=\"toc-item-num\">1 </span>Libraries</a></span></li><li><span><a href=\"#RPR-Implementations\" data-toc-modified-id=\"RPR-Implementations-2\"><span class=\"toc-item-num\">2 </span>RPR Implementations</a></span></li><li><span><a href=\"#Test\" data-toc-modified-id=\"Test-3\"><span class=\"toc-item-num\">3 </span>Test</a></span></li></ul></div>" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"colab_type": "text", | |
"id": "Wnx_-nGABgaF" | |
}, | |
"source": [ | |
"# Libraries" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 306 | |
}, | |
"colab_type": "code", | |
"id": "YNhmpbh4EtNv", | |
"outputId": "fde8fa47-f866-4809-e232-118427f2a82d" | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Sun Apr 21 09:48:47 2019 \n", | |
"+-----------------------------------------------------------------------------+\n", | |
"| NVIDIA-SMI 418.56 Driver Version: 410.79 CUDA Version: 10.0 |\n", | |
"|-------------------------------+----------------------+----------------------+\n", | |
"| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n", | |
"| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n", | |
"|===============================+======================+======================|\n", | |
"| 0 Tesla T4 Off | 00000000:00:04.0 Off | 0 |\n", | |
"| N/A 68C P8 17W / 70W | 0MiB / 15079MiB | 0% Default |\n", | |
"+-------------------------------+----------------------+----------------------+\n", | |
" \n", | |
"+-----------------------------------------------------------------------------+\n", | |
"| Processes: GPU Memory |\n", | |
"| GPU PID Type Process name Usage |\n", | |
"|=============================================================================|\n", | |
"| No running processes found |\n", | |
"+-----------------------------------------------------------------------------+\n" | |
] | |
} | |
], | |
"source": [ | |
"!nvidia-smi" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 340 | |
}, | |
"colab_type": "code", | |
"id": "Ao7z75xuDN8W", | |
"outputId": "51a5bd6f-5a4b-4231-c5ba-2d79d1e4e7dd" | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Requirement already satisfied: tensorflow-gpu==2.0.0-alpha0 in /usr/local/lib/python3.6/dist-packages (2.0.0a0)\n", | |
"Requirement already satisfied: keras-preprocessing>=1.0.5 in /usr/local/lib/python3.6/dist-packages (from tensorflow-gpu==2.0.0-alpha0) (1.0.9)\n", | |
"Requirement already satisfied: six>=1.10.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow-gpu==2.0.0-alpha0) (1.11.0)\n", | |
"Requirement already satisfied: protobuf>=3.6.1 in /usr/local/lib/python3.6/dist-packages (from tensorflow-gpu==2.0.0-alpha0) (3.7.1)\n", | |
"Requirement already satisfied: astor>=0.6.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow-gpu==2.0.0-alpha0) (0.7.1)\n", | |
"Requirement already satisfied: termcolor>=1.1.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow-gpu==2.0.0-alpha0) (1.1.0)\n", | |
"Requirement already satisfied: tf-estimator-nightly<1.14.0.dev2019030116,>=1.14.0.dev2019030115 in /usr/local/lib/python3.6/dist-packages (from tensorflow-gpu==2.0.0-alpha0) (1.14.0.dev2019030115)\n", | |
"Requirement already satisfied: keras-applications>=1.0.6 in /usr/local/lib/python3.6/dist-packages (from tensorflow-gpu==2.0.0-alpha0) (1.0.7)\n", | |
"Requirement already satisfied: google-pasta>=0.1.2 in /usr/local/lib/python3.6/dist-packages (from tensorflow-gpu==2.0.0-alpha0) (0.1.5)\n", | |
"Requirement already satisfied: tb-nightly<1.14.0a20190302,>=1.14.0a20190301 in /usr/local/lib/python3.6/dist-packages (from tensorflow-gpu==2.0.0-alpha0) (1.14.0a20190301)\n", | |
"Requirement already satisfied: grpcio>=1.8.6 in /usr/local/lib/python3.6/dist-packages (from tensorflow-gpu==2.0.0-alpha0) (1.15.0)\n", | |
"Requirement already satisfied: wheel>=0.26 in /usr/local/lib/python3.6/dist-packages (from tensorflow-gpu==2.0.0-alpha0) (0.33.1)\n", | |
"Requirement already satisfied: absl-py>=0.7.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow-gpu==2.0.0-alpha0) (0.7.1)\n", | |
"Requirement already satisfied: numpy<2.0,>=1.14.5 in /usr/local/lib/python3.6/dist-packages (from tensorflow-gpu==2.0.0-alpha0) (1.16.2)\n", | |
"Requirement already satisfied: gast>=0.2.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow-gpu==2.0.0-alpha0) (0.2.2)\n", | |
"Requirement already satisfied: setuptools in /usr/local/lib/python3.6/dist-packages (from protobuf>=3.6.1->tensorflow-gpu==2.0.0-alpha0) (40.9.0)\n", | |
"Requirement already satisfied: h5py in /usr/local/lib/python3.6/dist-packages (from keras-applications>=1.0.6->tensorflow-gpu==2.0.0-alpha0) (2.8.0)\n", | |
"Requirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python3.6/dist-packages (from tb-nightly<1.14.0a20190302,>=1.14.0a20190301->tensorflow-gpu==2.0.0-alpha0) (3.1)\n", | |
"Requirement already satisfied: werkzeug>=0.11.15 in /usr/local/lib/python3.6/dist-packages (from tb-nightly<1.14.0a20190302,>=1.14.0a20190301->tensorflow-gpu==2.0.0-alpha0) (0.15.2)\n" | |
] | |
} | |
], | |
"source": [ | |
"!pip install tensorflow-gpu==2.0.0-alpha0" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 0, | |
"metadata": { | |
"colab": {}, | |
"colab_type": "code", | |
"id": "31TscOKbBgaO" | |
}, | |
"outputs": [], | |
"source": [ | |
"import tensorflow as tf" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 34 | |
}, | |
"colab_type": "code", | |
"id": "gtzVk7yZBgaS", | |
"outputId": "b0f18341-ad6d-4ee8-c0c3-1e5c610c9ed0" | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"'2.0.0-alpha0'" | |
] | |
}, | |
"execution_count": 4, | |
"metadata": { | |
"tags": [] | |
}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"tf.__version__" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"colab_type": "text", | |
"id": "jcVaDB2QBgae" | |
}, | |
"source": [ | |
"# RPR Implementations" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"colab_type": "text", | |
"id": "89paiICNBgae" | |
}, | |
"source": [ | |
"Original implementation:\n", | |
"\n", | |
"Note: This function is copied from the [tensor2tensor repository (`common_attention.py`)](https://github.com/tensorflow/tensor2tensor/blob/9e0a894034d8090892c238df1bd9bd3180c2b9a3/tensor2tensor/layers/common_attention.py#L1556-L1587)." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 0, | |
"metadata": { | |
"colab": {}, | |
"colab_type": "code", | |
"id": "dGY4eCPtBgaf" | |
}, | |
"outputs": [], | |
"source": [ | |
"def _relative_attention_inner(x, y, z, transpose):\n", | |
" \"\"\"Relative position-aware dot-product attention inner calculation.\n", | |
" This batches matrix multiply calculations to avoid unnecessary broadcasting.\n", | |
" Args:\n", | |
" x: Tensor with shape [batch_size, heads, length or 1, length or depth].\n", | |
" y: Tensor with shape [batch_size, heads, length or 1, depth].\n", | |
" z: Tensor with shape [length or 1, length, depth].\n", | |
" transpose: Whether to transpose inner matrices of y and z. Should be true if\n", | |
" last dimension of x is depth, not length.\n", | |
" Returns:\n", | |
" A Tensor with shape [batch_size, heads, length, length or depth].\n", | |
" \"\"\"\n", | |
" batch_size = tf.shape(x)[0]\n", | |
" heads = x.get_shape().as_list()[1]\n", | |
" length = tf.shape(x)[2]\n", | |
"\n", | |
" # xy_matmul is [batch_size, heads, length or 1, length or depth]\n", | |
" xy_matmul = tf.matmul(x, y, transpose_b=transpose)\n", | |
" # x_t is [length or 1, batch_size, heads, length or depth]\n", | |
" x_t = tf.transpose(x, [2, 0, 1, 3])\n", | |
" # x_t_r is [length or 1, batch_size * heads, length or depth]\n", | |
" x_t_r = tf.reshape(x_t, [length, heads * batch_size, -1])\n", | |
" # x_tz_matmul is [length or 1, batch_size * heads, length or depth]\n", | |
" x_tz_matmul = tf.matmul(x_t_r, z, transpose_b=transpose)\n", | |
" # x_tz_matmul_r is [length or 1, batch_size, heads, length or depth]\n", | |
" x_tz_matmul_r = tf.reshape(x_tz_matmul, [length, batch_size, heads, -1])\n", | |
" # x_tz_matmul_r_t is [batch_size, heads, length or 1, length or depth]\n", | |
" x_tz_matmul_r_t = tf.transpose(x_tz_matmul_r, [1, 2, 0, 3])\n", | |
" return xy_matmul + x_tz_matmul_r_t" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"colab_type": "text", | |
"id": "SAts-V8bBgai" | |
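}, | |
"source": [ | |
"Implementation with einsum (only works for the \"key\" part, i.e. it reproduces the original function called with `transpose=True`; there is no `transpose` parameter, so the `transpose=False` case is not covered):" | |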
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 0, | |
"metadata": { | |
"colab": {}, | |
"colab_type": "code", | |
"id": "0ytrqbhmBgaj" | |
}, | |
"outputs": [], | |
"source": [ | |
"def _relative_attention_inner_with_einsum(x, y, z):\n", | |
" # let:\n", | |
" # b: batch size axis\n", | |
" # h: heads axis\n", | |
" # i: row axis\n", | |
" # j: column axis\n", | |
" # of the output in the einsum expression\n", | |
" first_term = tf.einsum('bhik,bhjk->bhij', x, y)\n", | |
" second_term = tf.einsum('bhik,ijk->bhij', x, z)\n", | |
" return first_term + second_term" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"colab_type": "text", | |
"id": "Z-nS6DdHBgan" | |
}, | |
"source": [ | |
"# Test" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"colab_type": "text", | |
"id": "QTJQMUgKBgat" | |
}, | |
"source": [ | |
"Define some parameters to control the size of the input tensors:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 0, | |
"metadata": { | |
"colab": {}, | |
"colab_type": "code", | |
"id": "XelafKj8Bgav" | |
}, | |
"outputs": [], | |
"source": [ | |
"batch_size = 32\n", | |
"heads = 8\n", | |
"length = 128\n", | |
"\n", | |
"# embedding size\n", | |
"depth = 768" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"colab_type": "text", | |
"id": "9zgJE2dUBga-" | |
}, | |
"source": [ | |
"Define the query matrix, key matrix and relative position embeddings matrix:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 0, | |
"metadata": { | |
"colab": {}, | |
"colab_type": "code", | |
"id": "1YvG1kwqBgbA" | |
}, | |
"outputs": [], | |
"source": [ | |
"Q = tf.range(1, batch_size * heads * length * depth + 1)\n", | |
"Q = tf.cast(Q, tf.float32)\n", | |
"Q = tf.reshape(Q, (batch_size, heads, length, depth))\n", | |
"#print(Q)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 0, | |
"metadata": { | |
"colab": {}, | |
"colab_type": "code", | |
"id": "a8hAhKdWBgbH" | |
}, | |
"outputs": [], | |
"source": [ | |
"K = tf.range(1, batch_size * heads * length * depth + 1)\n", | |
"K = tf.cast(K, tf.float32)\n", | |
"K = tf.reshape(K, (batch_size, heads, length, depth))\n", | |
"#print(K)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 0, | |
"metadata": { | |
"colab": {}, | |
"colab_type": "code", | |
"id": "jSAKONCaBgbQ" | |
}, | |
"outputs": [], | |
"source": [ | |
"rpr_embeddings_for_keys = tf.range(1, length * length * depth + 1)\n", | |
"rpr_embeddings_for_keys = tf.cast(rpr_embeddings_for_keys, tf.float32)\n", | |
"rpr_embeddings_for_keys = tf.reshape(rpr_embeddings_for_keys, (length, length, depth))\n", | |
"#print(rpr_embeddings_for_keys)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"colab_type": "text", | |
"id": "U1kHLl0KBgbU" | |
}, | |
"source": [ | |
"Given this equation:\n", | |
"\n", | |
"![](https://cdn-images-1.medium.com/max/1000/1*VB1i8gI67cPHQ7bkmVVk2g.png)\n", | |
"\n", | |
"We will compute the numerator (for a batch of inputs) using the original implementation and the einsum implementation." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 0, | |
"metadata": { | |
"colab": {}, | |
"colab_type": "code", | |
"id": "y8qCLl-IBgbV" | |
}, | |
"outputs": [], | |
"source": [ | |
"ori_result = _relative_attention_inner(x=Q,\n", | |
" y=K,\n", | |
" z=rpr_embeddings_for_keys,\n", | |
" transpose=True)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 0, | |
"metadata": { | |
"colab": {}, | |
"colab_type": "code", | |
"id": "kutokFKUBgbY" | |
}, | |
"outputs": [], | |
"source": [ | |
"einsum_result = _relative_attention_inner_with_einsum(x=Q,\n", | |
" y=K,\n", | |
" z=rpr_embeddings_for_keys)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"colab_type": "text", | |
"id": "YAPTpBkABgbf" | |
}, | |
"source": [ | |
"Check that both results are equal:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 0, | |
"metadata": { | |
"colab": {}, | |
"colab_type": "code", | |
"id": "xKR1VXgnBgbg" | |
}, | |
"outputs": [], | |
"source": [ | |
"is_equal = tf.equal(ori_result, einsum_result)\n", | |
"assert tf.reduce_all(is_equal)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"colab_type": "text", | |
"id": "BGqxN1vnBgbo" | |
}, | |
"source": [ | |
"Time the implementations:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 34 | |
}, | |
"colab_type": "code", | |
"id": "1HxsM2hjBgbu", | |
"outputId": "9be04af8-7c8d-42f9-c23f-65b125b5e79f" | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"1000 loops, best of 3: 5.81 ms per loop\n" | |
] | |
} | |
], | |
"source": [ | |
"%%timeit -n 1000\n", | |
"_relative_attention_inner(x=Q,\n", | |
" y=K,\n", | |
" z=rpr_embeddings_for_keys,\n", | |
" transpose=True)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 34 | |
}, | |
"colab_type": "code", | |
"id": "mzpDlAKRBgb0", | |
"outputId": "8e711742-2095-4d34-a124-4982222dea0d" | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"1000 loops, best of 3: 7.09 ms per loop\n" | |
] | |
} | |
], | |
"source": [ | |
"%%timeit -n 1000\n", | |
"_relative_attention_inner_with_einsum(x=Q,\n", | |
" y=K,\n", | |
" z=rpr_embeddings_for_keys)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"colab_type": "text", | |
"id": "VXM-xxJGBgb-" | |
}, | |
"source": [ | |
"The einsum implementation is slower overall, but what if we measure only the time it takes to compute the second term?" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"colab_type": "text", | |
"id": "01W5esTbBgcB" | |
}, | |
"source": [ | |
"Some setup:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 0, | |
"metadata": { | |
"colab": {}, | |
"colab_type": "code", | |
"id": "w2xv6ncyBgcC" | |
}, | |
"outputs": [], | |
"source": [ | |
"x = Q\n", | |
"y = K\n", | |
"z = rpr_embeddings_for_keys\n", | |
"transpose = True" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"colab_type": "text", | |
"id": "5aSZH7ExBgcH" | |
}, | |
"source": [ | |
"Original implementation:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 17, | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 34 | |
}, | |
"colab_type": "code", | |
"id": "GuSbtyZdBgcJ", | |
"outputId": "34db91e0-95df-4a62-f112-29d70e067431" | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"1000 loops, best of 3: 4.01 ms per loop\n" | |
] | |
} | |
], | |
"source": [ | |
"%%timeit -n 1000\n", | |
"\n", | |
"x_t = tf.transpose(x, [2, 0, 1, 3])\n", | |
"# x_t_r is [length or 1, batch_size * heads, length or depth]\n", | |
"x_t_r = tf.reshape(x_t, [length, heads * batch_size, -1])\n", | |
"# x_tz_matmul is [length or 1, batch_size * heads, length or depth]\n", | |
"x_tz_matmul = tf.matmul(x_t_r, z, transpose_b=transpose)\n", | |
"# x_tz_matmul_r is [length or 1, batch_size, heads, length or depth]\n", | |
"x_tz_matmul_r = tf.reshape(x_tz_matmul, [length, batch_size, heads, -1])\n", | |
"# x_tz_matmul_r_t is [batch_size, heads, length or 1, length or depth]\n", | |
"x_tz_matmul_r_t = tf.transpose(x_tz_matmul_r, [1, 2, 0, 3])" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"colab_type": "text", | |
"id": "K0ttkQr9BgcP" | |
}, | |
"source": [ | |
"Einsum implementation:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 18, | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 34 | |
}, | |
"colab_type": "code", | |
"id": "BpbGajDdBgcP", | |
"outputId": "0a563514-7e3f-464f-ac84-6f6908ea55e4" | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"1000 loops, best of 3: 3.94 ms per loop\n" | |
] | |
} | |
], | |
"source": [ | |
"%%timeit -n 1000\n", | |
"second_term = tf.einsum('bhik,ijk->bhij', x, z)" | |
] | |
} | |
], | |
"metadata": { | |
"accelerator": "GPU", | |
"colab": { | |
"collapsed_sections": [], | |
"name": "rpr_with_einsum_gpu.ipynb", | |
"provenance": [], | |
"version": "0.3.2" | |
}, | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.8" | |
}, | |
"toc": { | |
"nav_menu": {}, | |
"number_sections": true, | |
"sideBar": true, | |
"skip_h1_title": false, | |
"toc_cell": true, | |
"toc_position": {}, | |
"toc_section_display": "block", | |
"toc_window_display": true | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 1 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment