@hsm207
Last active April 21, 2019 11:08

Code to accompany my blog post at https://bit.ly/2PmRjiC

{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "dehkToEyBgaE",
"toc": true
},
"source": [
"<h1>Table of Contents<span class=\"tocSkip\"></span></h1>\n",
"<div class=\"toc\" style=\"margin-top: 1em;\"><ul class=\"toc-item\"><li><span><a href=\"#Libraries\" data-toc-modified-id=\"Libraries-1\"><span class=\"toc-item-num\">1&nbsp;&nbsp;</span>Libraries</a></span></li><li><span><a href=\"#RPR-Implementations\" data-toc-modified-id=\"RPR-Implementations-2\"><span class=\"toc-item-num\">2&nbsp;&nbsp;</span>RPR Implementations</a></span></li><li><span><a href=\"#Test\" data-toc-modified-id=\"Test-3\"><span class=\"toc-item-num\">3&nbsp;&nbsp;</span>Test</a></span></li></ul></div>"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "Wnx_-nGABgaF"
},
"source": [
"# Libraries"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 306
},
"colab_type": "code",
"id": "YNhmpbh4EtNv",
"outputId": "fde8fa47-f866-4809-e232-118427f2a82d"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Sun Apr 21 09:48:47 2019 \n",
"+-----------------------------------------------------------------------------+\n",
"| NVIDIA-SMI 418.56 Driver Version: 410.79 CUDA Version: 10.0 |\n",
"|-------------------------------+----------------------+----------------------+\n",
"| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n",
"| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n",
"|===============================+======================+======================|\n",
"| 0 Tesla T4 Off | 00000000:00:04.0 Off | 0 |\n",
"| N/A 68C P8 17W / 70W | 0MiB / 15079MiB | 0% Default |\n",
"+-------------------------------+----------------------+----------------------+\n",
" \n",
"+-----------------------------------------------------------------------------+\n",
"| Processes: GPU Memory |\n",
"| GPU PID Type Process name Usage |\n",
"|=============================================================================|\n",
"| No running processes found |\n",
"+-----------------------------------------------------------------------------+\n"
]
}
],
"source": [
"!nvidia-smi"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 340
},
"colab_type": "code",
"id": "Ao7z75xuDN8W",
"outputId": "51a5bd6f-5a4b-4231-c5ba-2d79d1e4e7dd"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: tensorflow-gpu==2.0.0-alpha0 in /usr/local/lib/python3.6/dist-packages (2.0.0a0)\n",
"Requirement already satisfied: keras-preprocessing>=1.0.5 in /usr/local/lib/python3.6/dist-packages (from tensorflow-gpu==2.0.0-alpha0) (1.0.9)\n",
"Requirement already satisfied: six>=1.10.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow-gpu==2.0.0-alpha0) (1.11.0)\n",
"Requirement already satisfied: protobuf>=3.6.1 in /usr/local/lib/python3.6/dist-packages (from tensorflow-gpu==2.0.0-alpha0) (3.7.1)\n",
"Requirement already satisfied: astor>=0.6.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow-gpu==2.0.0-alpha0) (0.7.1)\n",
"Requirement already satisfied: termcolor>=1.1.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow-gpu==2.0.0-alpha0) (1.1.0)\n",
"Requirement already satisfied: tf-estimator-nightly<1.14.0.dev2019030116,>=1.14.0.dev2019030115 in /usr/local/lib/python3.6/dist-packages (from tensorflow-gpu==2.0.0-alpha0) (1.14.0.dev2019030115)\n",
"Requirement already satisfied: keras-applications>=1.0.6 in /usr/local/lib/python3.6/dist-packages (from tensorflow-gpu==2.0.0-alpha0) (1.0.7)\n",
"Requirement already satisfied: google-pasta>=0.1.2 in /usr/local/lib/python3.6/dist-packages (from tensorflow-gpu==2.0.0-alpha0) (0.1.5)\n",
"Requirement already satisfied: tb-nightly<1.14.0a20190302,>=1.14.0a20190301 in /usr/local/lib/python3.6/dist-packages (from tensorflow-gpu==2.0.0-alpha0) (1.14.0a20190301)\n",
"Requirement already satisfied: grpcio>=1.8.6 in /usr/local/lib/python3.6/dist-packages (from tensorflow-gpu==2.0.0-alpha0) (1.15.0)\n",
"Requirement already satisfied: wheel>=0.26 in /usr/local/lib/python3.6/dist-packages (from tensorflow-gpu==2.0.0-alpha0) (0.33.1)\n",
"Requirement already satisfied: absl-py>=0.7.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow-gpu==2.0.0-alpha0) (0.7.1)\n",
"Requirement already satisfied: numpy<2.0,>=1.14.5 in /usr/local/lib/python3.6/dist-packages (from tensorflow-gpu==2.0.0-alpha0) (1.16.2)\n",
"Requirement already satisfied: gast>=0.2.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow-gpu==2.0.0-alpha0) (0.2.2)\n",
"Requirement already satisfied: setuptools in /usr/local/lib/python3.6/dist-packages (from protobuf>=3.6.1->tensorflow-gpu==2.0.0-alpha0) (40.9.0)\n",
"Requirement already satisfied: h5py in /usr/local/lib/python3.6/dist-packages (from keras-applications>=1.0.6->tensorflow-gpu==2.0.0-alpha0) (2.8.0)\n",
"Requirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python3.6/dist-packages (from tb-nightly<1.14.0a20190302,>=1.14.0a20190301->tensorflow-gpu==2.0.0-alpha0) (3.1)\n",
"Requirement already satisfied: werkzeug>=0.11.15 in /usr/local/lib/python3.6/dist-packages (from tb-nightly<1.14.0a20190302,>=1.14.0a20190301->tensorflow-gpu==2.0.0-alpha0) (0.15.2)\n"
]
}
],
"source": [
"!pip install tensorflow-gpu==2.0.0-alpha0"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "31TscOKbBgaO"
},
"outputs": [],
"source": [
"import tensorflow as tf"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"colab_type": "code",
"id": "gtzVk7yZBgaS",
"outputId": "b0f18341-ad6d-4ee8-c0c3-1e5c610c9ed0"
},
"outputs": [
{
"data": {
"text/plain": [
"'2.0.0-alpha0'"
]
},
"execution_count": 4,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"tf.__version__"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "jcVaDB2QBgae"
},
"source": [
"# RPR Implementations"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "89paiICNBgae"
},
"source": [
"Original implementation:\n",
"\n",
"Note: This function is taken from the [tensor2tensor](https://github.com/tensorflow/tensor2tensor/blob/9e0a894034d8090892c238df1bd9bd3180c2b9a3/tensor2tensor/layers/common_attention.py#L1556-L1587) repository."
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "dGY4eCPtBgaf"
},
"outputs": [],
"source": [
"def _relative_attention_inner(x, y, z, transpose):\n",
" \"\"\"Relative position-aware dot-product attention inner calculation.\n",
" This batches matrix multiply calculations to avoid unnecessary broadcasting.\n",
" Args:\n",
" x: Tensor with shape [batch_size, heads, length or 1, length or depth].\n",
" y: Tensor with shape [batch_size, heads, length or 1, depth].\n",
" z: Tensor with shape [length or 1, length, depth].\n",
" transpose: Whether to transpose inner matrices of y and z. Should be true if\n",
" last dimension of x is depth, not length.\n",
" Returns:\n",
" A Tensor with shape [batch_size, heads, length, length or depth].\n",
" \"\"\"\n",
" batch_size = tf.shape(x)[0]\n",
" heads = x.get_shape().as_list()[1]\n",
" length = tf.shape(x)[2]\n",
"\n",
" # xy_matmul is [batch_size, heads, length or 1, length or depth]\n",
" xy_matmul = tf.matmul(x, y, transpose_b=transpose)\n",
" # x_t is [length or 1, batch_size, heads, length or depth]\n",
" x_t = tf.transpose(x, [2, 0, 1, 3])\n",
" # x_t_r is [length or 1, batch_size * heads, length or depth]\n",
" x_t_r = tf.reshape(x_t, [length, heads * batch_size, -1])\n",
" # x_tz_matmul is [length or 1, batch_size * heads, length or depth]\n",
" x_tz_matmul = tf.matmul(x_t_r, z, transpose_b=transpose)\n",
" # x_tz_matmul_r is [length or 1, batch_size, heads, length or depth]\n",
" x_tz_matmul_r = tf.reshape(x_tz_matmul, [length, batch_size, heads, -1])\n",
" # x_tz_matmul_r_t is [batch_size, heads, length or 1, length or depth]\n",
" x_tz_matmul_r_t = tf.transpose(x_tz_matmul_r, [1, 2, 0, 3])\n",
" return xy_matmul + x_tz_matmul_r_t"
]
},
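{
"cell_type": "markdown",
"metadata": {},
"source": [
"An aside (not in the original post): the reshape/transpose dance above works because `tf.matmul` batches over leading axes. Multiplying a `[length, batch_size * heads, depth]` tensor by a `[length, length, depth]` tensor with `transpose_b=True` performs `length` independent `[batch_size * heads, depth] @ [depth, length]` products, one per row position. A quick shape check with arbitrary illustrative sizes:"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {},
"outputs": [],
"source": [
"a = tf.ones([5, 6, 4])  # stands in for [length, batch_size * heads, depth]\n",
"b = tf.ones([5, 3, 4])  # stands in for [length, length', depth]\n",
"out = tf.matmul(a, b, transpose_b=True)\n",
"# one [6, 4] @ [4, 3] product per leading row position\n",
"assert out.shape.as_list() == [5, 6, 3]"
]
},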
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "SAts-V8bBgai"
},
"source": [
"Implementation with einsum (this handles only the \"key\" part of relative attention):"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "0ytrqbhmBgaj"
},
"outputs": [],
"source": [
"def _relative_attention_inner_with_einsum(x, y, z):\n",
" # axis labels used in the einsum expressions:\n",
" # b: batch size\n",
" # h: heads\n",
" # i: row of the output\n",
" # j: column of the output\n",
" # k: depth (the contracted axis, summed over)\n",
" first_term = tf.einsum('bhik,bhjk->bhij', x, y)\n",
" second_term = tf.einsum('bhik,ijk->bhij', x, z)\n",
" return first_term + second_term"
]
},
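{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check (added here, not part of the original post), we can verify on tiny random arrays that `tf.einsum('bhik,ijk->bhij', x, z)` matches an explicit loop over the output indices; the sizes below are arbitrary illustrative choices:"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"bs, nh, n, d = 2, 2, 3, 4  # tiny illustrative sizes\n",
"x_small = np.random.rand(bs, nh, n, d).astype(np.float32)\n",
"z_small = np.random.rand(n, n, d).astype(np.float32)\n",
"\n",
"# explicit loop: out[b, h, i, j] = sum over k of x[b, h, i, k] * z[i, j, k]\n",
"expected = np.zeros((bs, nh, n, n), dtype=np.float32)\n",
"for bi in range(bs):\n",
"    for hi in range(nh):\n",
"        for i in range(n):\n",
"            for j in range(n):\n",
"                expected[bi, hi, i, j] = np.dot(x_small[bi, hi, i], z_small[i, j])\n",
"\n",
"actual = tf.einsum('bhik,ijk->bhij', x_small, z_small).numpy()\n",
"assert np.allclose(actual, expected)"
]
},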
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "Z-nS6DdHBgan"
},
"source": [
"# Test"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "QTJQMUgKBgat"
},
"source": [
"Define some parameters to control the size of the input tensors:"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "XelafKj8Bgav"
},
"outputs": [],
"source": [
"batch_size = 32\n",
"heads = 8\n",
"length = 128\n",
"\n",
"# embedding size\n",
"depth = 768"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "9zgJE2dUBga-"
},
"source": [
"Define the query matrix, key matrix, and relative position embeddings matrix:"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "1YvG1kwqBgbA"
},
"outputs": [],
"source": [
"Q = tf.range(1, batch_size * heads * length * depth + 1)\n",
"Q = tf.cast(Q, tf.float32)\n",
"Q = tf.reshape(Q, (batch_size, heads, length, depth))\n",
"#print(Q)"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "a8hAhKdWBgbH"
},
"outputs": [],
"source": [
"K = tf.range(1, batch_size * heads * length * depth + 1)\n",
"K = tf.cast(K, tf.float32)\n",
"K = tf.reshape(K, (batch_size, heads, length, depth))\n",
"#print(K)"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "jSAKONCaBgbQ"
},
"outputs": [],
"source": [
"rpr_embeddings_for_keys = tf.range(1, length * length * depth + 1)\n",
"rpr_embeddings_for_keys = tf.cast(rpr_embeddings_for_keys, tf.float32)\n",
"rpr_embeddings_for_keys = tf.reshape(rpr_embeddings_for_keys, (length, length, depth))\n",
"#print(rpr_embeddings_for_keys)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "U1kHLl0KBgbU"
},
"source": [
"Given this equation:\n",
"\n",
"![](https://cdn-images-1.medium.com/max/1000/1*VB1i8gI67cPHQ7bkmVVk2g.png)\n",
"\n",
"We will compute the numerator (for a batch of inputs) using the original implementation and the einsum implementation."
]
},
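{
"cell_type": "markdown",
"metadata": {},
"source": [
"Concretely (expansion added for clarity): writing $q_i = x_i W^Q$ and $k_j = x_j W^K$, the numerator $q_i (k_j + a^K_{ij})^\\top$ splits into $q_i k_j^\\top + q_i (a^K_{ij})^\\top$. The first term is the usual query-key dot product (`first_term` in the einsum implementation), and the second term involves the relative position embeddings (`second_term`)."
]
},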
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "y8qCLl-IBgbV"
},
"outputs": [],
"source": [
"ori_result = _relative_attention_inner(x=Q,\n",
" y=K,\n",
" z=rpr_embeddings_for_keys,\n",
" transpose=True)"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "kutokFKUBgbY"
},
"outputs": [],
"source": [
"einsum_result = _relative_attention_inner_with_einsum(x=Q,\n",
" y=K,\n",
" z=rpr_embeddings_for_keys)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "YAPTpBkABgbf"
},
"source": [
"Check that both results are equal:"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "xKR1VXgnBgbg"
},
"outputs": [],
"source": [
"# note: this asserts exact elementwise equality, which happens to hold\n",
"# here; for floating-point results in general, a tolerance-based check\n",
"# such as tf.debugging.assert_near is safer\n",
"is_equal = tf.equal(ori_result, einsum_result)\n",
"assert tf.reduce_all(is_equal)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "BGqxN1vnBgbo"
},
"source": [
"Time the implementations:"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"colab_type": "code",
"id": "1HxsM2hjBgbu",
"outputId": "9be04af8-7c8d-42f9-c23f-65b125b5e79f"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1000 loops, best of 3: 5.81 ms per loop\n"
]
}
],
"source": [
"%%timeit -n 1000\n",
"_relative_attention_inner(x=Q,\n",
" y=K,\n",
" z=rpr_embeddings_for_keys,\n",
" transpose=True)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"colab_type": "code",
"id": "mzpDlAKRBgb0",
"outputId": "8e711742-2095-4d34-a124-4982222dea0d"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1000 loops, best of 3: 7.09 ms per loop\n"
]
}
],
"source": [
"%%timeit -n 1000\n",
"_relative_attention_inner_with_einsum(x=Q,\n",
" y=K,\n",
" z=rpr_embeddings_for_keys)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "VXM-xxJGBgb-"
},
"source": [
"The einsum implementation is slower overall, but what if we time only how long it takes to compute the second term?"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "01W5esTbBgcB"
},
"source": [
"Some setup:"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "w2xv6ncyBgcC"
},
"outputs": [],
"source": [
"x = Q\n",
"y = K\n",
"z = rpr_embeddings_for_keys\n",
"transpose = True"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "5aSZH7ExBgcH"
},
"source": [
"Original implementation:"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"colab_type": "code",
"id": "GuSbtyZdBgcJ",
"outputId": "34db91e0-95df-4a62-f112-29d70e067431"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1000 loops, best of 3: 4.01 ms per loop\n"
]
}
],
"source": [
"%%timeit -n 1000\n",
"\n",
"x_t = tf.transpose(x, [2, 0, 1, 3])\n",
"# x_t_r is [length or 1, batch_size * heads, length or depth]\n",
"x_t_r = tf.reshape(x_t, [length, heads * batch_size, -1])\n",
"# x_tz_matmul is [length or 1, batch_size * heads, length or depth]\n",
"x_tz_matmul = tf.matmul(x_t_r, z, transpose_b=transpose)\n",
"# x_tz_matmul_r is [length or 1, batch_size, heads, length or depth]\n",
"x_tz_matmul_r = tf.reshape(x_tz_matmul, [length, batch_size, heads, -1])\n",
"# x_tz_matmul_r_t is [batch_size, heads, length or 1, length or depth]\n",
"x_tz_matmul_r_t = tf.transpose(x_tz_matmul_r, [1, 2, 0, 3])"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "K0ttkQr9BgcP"
},
"source": [
"Einsum implementation:"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"colab_type": "code",
"id": "BpbGajDdBgcP",
"outputId": "0a563514-7e3f-464f-ac84-6f6908ea55e4"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1000 loops, best of 3: 3.94 ms per loop\n"
]
}
],
"source": [
"%%timeit -n 1000\n",
"second_term = tf.einsum('bhik,ijk->bhij', x, z)"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [],
"name": "rpr_with_einsum_gpu.ipynb",
"provenance": [],
"version": "0.3.2"
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.8"
},
"toc": {
"nav_menu": {},
"number_sections": true,
"sideBar": true,
"skip_h1_title": false,
"toc_cell": true,
"toc_position": {},
"toc_section_display": "block",
"toc_window_display": true
}
},
"nbformat": 4,
"nbformat_minor": 1
}