@powderluv
Forked from carlthome/tfcompile.ipynb
Created January 15, 2020 23:05

Example of how to use XLA AOT via tfcompile to build a Keras model into a shared library.
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"# Deploying a TensorFlow graph via XLA AOT compilation\n",
"Many machine learning models are deployed as cloud services where you can accommodate a full-blown runtime, but managing servers and requiring internet connectivity for your app is a hassle. Instead, you can use tfcompile (a XLA CLI tool) to compile a TensorFlow graph to executable machine code, and then deploy that as a microservice or native application."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"# XLA\n",
"[XLA](https://www.tensorflow.org/performance/xla/) is a compiler of TensorFlow graphs.\n",
"\n",
"- TensorFlow's graph abstraction incurs overhead.\n",
"- XLA combats this so we can afford typing high-level code without relying on the existence of custom ops kernels.\n",
"- The compiler can be used for graph optimization during model training, but we'll focus on ahead-of-time (AOT) compilation for model deployment.\n",
"- Implementation is still maturing. XLA was released march last year and there are several commits per day."
]
},
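{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"As a minimal sketch (assuming TF 1.x, as used throughout this notebook), JIT compilation is switched on per session; we won't use it here, but it is the counterpart to the AOT path below:\n",
"\n",
"```python\n",
"import tensorflow as tf\n",
"\n",
"# Ask TensorFlow to JIT-compile eligible subgraphs with XLA.\n",
"config = tf.ConfigProto()\n",
"config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1\n",
"session = tf.Session(config=config)\n",
"```"
]
},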
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"![image.png](https://2.bp.blogspot.com/-yhjY3pc6oow/WLRn2z4mPBI/AAAAAAAACcU/t_EAR6QMwQQkTBPftJQEonaB2DMbRXmXwCLcB/s640/Screen%2BShot%2B2017-02-27%2Bat%2B9.54.12%2BAM.png)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"![](https://www.tensorflow.org/images/how-does-xla-work.png)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"# Steps for ahead-of-time compiling a graph with XLA\n",
"We'll use the command-line tool tfcompile via Bazel.\n",
"1. Configure the subgraph to compile.\n",
"1. Use the tf_library build macro to compile the subgraph.\n",
"1. Write code to invoke the subgraph.\n",
"1. Create the final binary."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## Step 0: Model\n",
"Before we start compiling a graph we need to build our graph. Let's keep it simple by just loading a pretrained image classifier."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"slideshow": {
"slide_type": "skip"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"env: CUDA_VISIBLE_DEVICES=''\n"
]
}
],
"source": [
"# This cell can be safely removed and doesn't need to be run.\n",
"%env CUDA_VISIBLE_DEVICES=''\n",
"import tensorflow as tf"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"________________________________________________________________________________\n",
"Layer (type) Output Shape Param # Connected to \n",
"================================================================================\n",
"input_1 (InputLayer) (None, 224, 224, 0 \n",
"________________________________________________________________________________\n",
"conv1 (Conv2D) (None, 112, 112, 9472 input_1[0][0] \n",
"________________________________________________________________________________\n",
"bn_conv1 (BatchNormalizat (None, 112, 112, 256 conv1[0][0] \n",
"________________________________________________________________________________\n",
"activation_1 (Activation) (None, 112, 112, 0 bn_conv1[0][0] \n",
"________________________________________________________________________________\n",
"max_pooling2d_1 (MaxPooli (None, 55, 55, 64 0 activation_1[0][0] \n",
"________________________________________________________________________________\n",
"res2a_branch2a (Conv2D) (None, 55, 55, 64 4160 max_pooling2d_1[0][0] \n",
"________________________________________________________________________________\n",
"bn2a_branch2a (BatchNorma (None, 55, 55, 64 256 res2a_branch2a[0][0] \n",
"________________________________________________________________________________\n",
"activation_2 (Activation) (None, 55, 55, 64 0 bn2a_branch2a[0][0] \n",
"________________________________________________________________________________\n",
"res2a_branch2b (Conv2D) (None, 55, 55, 64 36928 activation_2[0][0] \n",
"________________________________________________________________________________\n",
"bn2a_branch2b (BatchNorma (None, 55, 55, 64 256 res2a_branch2b[0][0] \n",
"________________________________________________________________________________\n",
"activation_3 (Activation) (None, 55, 55, 64 0 bn2a_branch2b[0][0] \n",
"________________________________________________________________________________\n",
"res2a_branch2c (Conv2D) (None, 55, 55, 25 16640 activation_3[0][0] \n",
"________________________________________________________________________________\n",
"res2a_branch1 (Conv2D) (None, 55, 55, 25 16640 max_pooling2d_1[0][0] \n",
"________________________________________________________________________________\n",
"bn2a_branch2c (BatchNorma (None, 55, 55, 25 1024 res2a_branch2c[0][0] \n",
"________________________________________________________________________________\n",
"bn2a_branch1 (BatchNormal (None, 55, 55, 25 1024 res2a_branch1[0][0] \n",
"________________________________________________________________________________\n",
"add_1 (Add) (None, 55, 55, 25 0 bn2a_branch2c[0][0] \n",
" bn2a_branch1[0][0] \n",
"________________________________________________________________________________\n",
"activation_4 (Activation) (None, 55, 55, 25 0 add_1[0][0] \n",
"________________________________________________________________________________\n",
"res2b_branch2a (Conv2D) (None, 55, 55, 64 16448 activation_4[0][0] \n",
"________________________________________________________________________________\n",
"bn2b_branch2a (BatchNorma (None, 55, 55, 64 256 res2b_branch2a[0][0] \n",
"________________________________________________________________________________\n",
"activation_5 (Activation) (None, 55, 55, 64 0 bn2b_branch2a[0][0] \n",
"________________________________________________________________________________\n",
"res2b_branch2b (Conv2D) (None, 55, 55, 64 36928 activation_5[0][0] \n",
"________________________________________________________________________________\n",
"bn2b_branch2b (BatchNorma (None, 55, 55, 64 256 res2b_branch2b[0][0] \n",
"________________________________________________________________________________\n",
"activation_6 (Activation) (None, 55, 55, 64 0 bn2b_branch2b[0][0] \n",
"________________________________________________________________________________\n",
"res2b_branch2c (Conv2D) (None, 55, 55, 25 16640 activation_6[0][0] \n",
"________________________________________________________________________________\n",
"bn2b_branch2c (BatchNorma (None, 55, 55, 25 1024 res2b_branch2c[0][0] \n",
"________________________________________________________________________________\n",
"add_2 (Add) (None, 55, 55, 25 0 bn2b_branch2c[0][0] \n",
" activation_4[0][0] \n",
"________________________________________________________________________________\n",
"activation_7 (Activation) (None, 55, 55, 25 0 add_2[0][0] \n",
"________________________________________________________________________________\n",
"res2c_branch2a (Conv2D) (None, 55, 55, 64 16448 activation_7[0][0] \n",
"________________________________________________________________________________\n",
"bn2c_branch2a (BatchNorma (None, 55, 55, 64 256 res2c_branch2a[0][0] \n",
"________________________________________________________________________________\n",
"activation_8 (Activation) (None, 55, 55, 64 0 bn2c_branch2a[0][0] \n",
"________________________________________________________________________________\n",
"res2c_branch2b (Conv2D) (None, 55, 55, 64 36928 activation_8[0][0] \n",
"________________________________________________________________________________\n",
"bn2c_branch2b (BatchNorma (None, 55, 55, 64 256 res2c_branch2b[0][0] \n",
"________________________________________________________________________________\n",
"activation_9 (Activation) (None, 55, 55, 64 0 bn2c_branch2b[0][0] \n",
"________________________________________________________________________________\n",
"res2c_branch2c (Conv2D) (None, 55, 55, 25 16640 activation_9[0][0] \n",
"________________________________________________________________________________\n",
"bn2c_branch2c (BatchNorma (None, 55, 55, 25 1024 res2c_branch2c[0][0] \n",
"________________________________________________________________________________\n",
"add_3 (Add) (None, 55, 55, 25 0 bn2c_branch2c[0][0] \n",
" activation_7[0][0] \n",
"________________________________________________________________________________\n",
"activation_10 (Activation (None, 55, 55, 25 0 add_3[0][0] \n",
"________________________________________________________________________________\n",
"res3a_branch2a (Conv2D) (None, 28, 28, 12 32896 activation_10[0][0] \n",
"________________________________________________________________________________\n",
"bn3a_branch2a (BatchNorma (None, 28, 28, 12 512 res3a_branch2a[0][0] \n",
"________________________________________________________________________________\n",
"activation_11 (Activation (None, 28, 28, 12 0 bn3a_branch2a[0][0] \n",
"________________________________________________________________________________\n",
"res3a_branch2b (Conv2D) (None, 28, 28, 12 147584 activation_11[0][0] \n",
"________________________________________________________________________________\n",
"bn3a_branch2b (BatchNorma (None, 28, 28, 12 512 res3a_branch2b[0][0] \n",
"________________________________________________________________________________\n",
"activation_12 (Activation (None, 28, 28, 12 0 bn3a_branch2b[0][0] \n",
"________________________________________________________________________________\n",
"res3a_branch2c (Conv2D) (None, 28, 28, 51 66048 activation_12[0][0] \n",
"________________________________________________________________________________\n",
"res3a_branch1 (Conv2D) (None, 28, 28, 51 131584 activation_10[0][0] \n",
"________________________________________________________________________________\n",
"bn3a_branch2c (BatchNorma (None, 28, 28, 51 2048 res3a_branch2c[0][0] \n",
"________________________________________________________________________________\n",
"bn3a_branch1 (BatchNormal (None, 28, 28, 51 2048 res3a_branch1[0][0] \n",
"________________________________________________________________________________\n",
"add_4 (Add) (None, 28, 28, 51 0 bn3a_branch2c[0][0] \n",
" bn3a_branch1[0][0] \n",
"________________________________________________________________________________\n",
"activation_13 (Activation (None, 28, 28, 51 0 add_4[0][0] \n",
"________________________________________________________________________________\n",
"res3b_branch2a (Conv2D) (None, 28, 28, 12 65664 activation_13[0][0] \n",
"________________________________________________________________________________\n",
"bn3b_branch2a (BatchNorma (None, 28, 28, 12 512 res3b_branch2a[0][0] \n",
"________________________________________________________________________________\n",
"activation_14 (Activation (None, 28, 28, 12 0 bn3b_branch2a[0][0] \n",
"________________________________________________________________________________\n",
"res3b_branch2b (Conv2D) (None, 28, 28, 12 147584 activation_14[0][0] \n",
"________________________________________________________________________________\n",
"bn3b_branch2b (BatchNorma (None, 28, 28, 12 512 res3b_branch2b[0][0] \n",
"________________________________________________________________________________\n",
"activation_15 (Activation (None, 28, 28, 12 0 bn3b_branch2b[0][0] \n",
"________________________________________________________________________________\n",
"res3b_branch2c (Conv2D) (None, 28, 28, 51 66048 activation_15[0][0] \n",
"________________________________________________________________________________\n",
"bn3b_branch2c (BatchNorma (None, 28, 28, 51 2048 res3b_branch2c[0][0] \n",
"________________________________________________________________________________\n",
"add_5 (Add) (None, 28, 28, 51 0 bn3b_branch2c[0][0] \n",
" activation_13[0][0] \n",
"________________________________________________________________________________\n",
"activation_16 (Activation (None, 28, 28, 51 0 add_5[0][0] \n",
"________________________________________________________________________________\n",
"res3c_branch2a (Conv2D) (None, 28, 28, 12 65664 activation_16[0][0] \n",
"________________________________________________________________________________\n",
"bn3c_branch2a (BatchNorma (None, 28, 28, 12 512 res3c_branch2a[0][0] \n",
"________________________________________________________________________________\n",
"activation_17 (Activation (None, 28, 28, 12 0 bn3c_branch2a[0][0] \n",
"________________________________________________________________________________\n",
"res3c_branch2b (Conv2D) (None, 28, 28, 12 147584 activation_17[0][0] \n",
"________________________________________________________________________________\n",
"bn3c_branch2b (BatchNorma (None, 28, 28, 12 512 res3c_branch2b[0][0] \n",
"________________________________________________________________________________\n",
"activation_18 (Activation (None, 28, 28, 12 0 bn3c_branch2b[0][0] \n",
"________________________________________________________________________________\n",
"res3c_branch2c (Conv2D) (None, 28, 28, 51 66048 activation_18[0][0] \n",
"________________________________________________________________________________\n",
"bn3c_branch2c (BatchNorma (None, 28, 28, 51 2048 res3c_branch2c[0][0] \n",
"________________________________________________________________________________\n",
"add_6 (Add) (None, 28, 28, 51 0 bn3c_branch2c[0][0] \n",
" activation_16[0][0] \n",
"________________________________________________________________________________\n",
"activation_19 (Activation (None, 28, 28, 51 0 add_6[0][0] \n",
"________________________________________________________________________________\n",
"res3d_branch2a (Conv2D) (None, 28, 28, 12 65664 activation_19[0][0] \n",
"________________________________________________________________________________\n",
"bn3d_branch2a (BatchNorma (None, 28, 28, 12 512 res3d_branch2a[0][0] \n",
"________________________________________________________________________________\n",
"activation_20 (Activation (None, 28, 28, 12 0 bn3d_branch2a[0][0] \n",
"________________________________________________________________________________\n",
"res3d_branch2b (Conv2D) (None, 28, 28, 12 147584 activation_20[0][0] \n",
"________________________________________________________________________________\n",
"bn3d_branch2b (BatchNorma (None, 28, 28, 12 512 res3d_branch2b[0][0] \n",
"________________________________________________________________________________\n",
"activation_21 (Activation (None, 28, 28, 12 0 bn3d_branch2b[0][0] \n",
"________________________________________________________________________________\n",
"res3d_branch2c (Conv2D) (None, 28, 28, 51 66048 activation_21[0][0] \n",
"________________________________________________________________________________\n",
"bn3d_branch2c (BatchNorma (None, 28, 28, 51 2048 res3d_branch2c[0][0] \n",
"________________________________________________________________________________\n",
"add_7 (Add) (None, 28, 28, 51 0 bn3d_branch2c[0][0] \n",
" activation_19[0][0] \n",
"________________________________________________________________________________\n",
"activation_22 (Activation (None, 28, 28, 51 0 add_7[0][0] \n",
"________________________________________________________________________________\n",
"res4a_branch2a (Conv2D) (None, 14, 14, 25 131328 activation_22[0][0] \n",
"________________________________________________________________________________\n",
"bn4a_branch2a (BatchNorma (None, 14, 14, 25 1024 res4a_branch2a[0][0] \n",
"________________________________________________________________________________\n",
"activation_23 (Activation (None, 14, 14, 25 0 bn4a_branch2a[0][0] \n",
"________________________________________________________________________________\n",
"res4a_branch2b (Conv2D) (None, 14, 14, 25 590080 activation_23[0][0] \n",
"________________________________________________________________________________\n",
"bn4a_branch2b (BatchNorma (None, 14, 14, 25 1024 res4a_branch2b[0][0] \n",
"________________________________________________________________________________\n",
"activation_24 (Activation (None, 14, 14, 25 0 bn4a_branch2b[0][0] \n",
"________________________________________________________________________________\n",
"res4a_branch2c (Conv2D) (None, 14, 14, 10 263168 activation_24[0][0] \n",
"________________________________________________________________________________\n",
"res4a_branch1 (Conv2D) (None, 14, 14, 10 525312 activation_22[0][0] \n",
"________________________________________________________________________________\n",
"bn4a_branch2c (BatchNorma (None, 14, 14, 10 4096 res4a_branch2c[0][0] \n",
"________________________________________________________________________________\n",
"bn4a_branch1 (BatchNormal (None, 14, 14, 10 4096 res4a_branch1[0][0] \n",
"________________________________________________________________________________\n",
"add_8 (Add) (None, 14, 14, 10 0 bn4a_branch2c[0][0] \n",
" bn4a_branch1[0][0] \n",
"________________________________________________________________________________\n",
"activation_25 (Activation (None, 14, 14, 10 0 add_8[0][0] \n",
"________________________________________________________________________________\n",
"res4b_branch2a (Conv2D) (None, 14, 14, 25 262400 activation_25[0][0] \n",
"________________________________________________________________________________\n",
"bn4b_branch2a (BatchNorma (None, 14, 14, 25 1024 res4b_branch2a[0][0] \n",
"________________________________________________________________________________\n",
"activation_26 (Activation (None, 14, 14, 25 0 bn4b_branch2a[0][0] \n",
"________________________________________________________________________________\n",
"res4b_branch2b (Conv2D) (None, 14, 14, 25 590080 activation_26[0][0] \n",
"________________________________________________________________________________\n",
"bn4b_branch2b (BatchNorma (None, 14, 14, 25 1024 res4b_branch2b[0][0] \n",
"________________________________________________________________________________\n",
"activation_27 (Activation (None, 14, 14, 25 0 bn4b_branch2b[0][0] \n",
"________________________________________________________________________________\n",
"res4b_branch2c (Conv2D) (None, 14, 14, 10 263168 activation_27[0][0] \n",
"________________________________________________________________________________\n",
"bn4b_branch2c (BatchNorma (None, 14, 14, 10 4096 res4b_branch2c[0][0] \n",
"________________________________________________________________________________\n",
"add_9 (Add) (None, 14, 14, 10 0 bn4b_branch2c[0][0] \n",
" activation_25[0][0] \n",
"________________________________________________________________________________\n",
"activation_28 (Activation (None, 14, 14, 10 0 add_9[0][0] \n",
"________________________________________________________________________________\n",
"res4c_branch2a (Conv2D) (None, 14, 14, 25 262400 activation_28[0][0] \n",
"________________________________________________________________________________\n",
"bn4c_branch2a (BatchNorma (None, 14, 14, 25 1024 res4c_branch2a[0][0] \n",
"________________________________________________________________________________\n",
"activation_29 (Activation (None, 14, 14, 25 0 bn4c_branch2a[0][0] \n",
"________________________________________________________________________________\n",
"res4c_branch2b (Conv2D) (None, 14, 14, 25 590080 activation_29[0][0] \n",
"________________________________________________________________________________\n",
"bn4c_branch2b (BatchNorma (None, 14, 14, 25 1024 res4c_branch2b[0][0] \n",
"________________________________________________________________________________\n",
"activation_30 (Activation (None, 14, 14, 25 0 bn4c_branch2b[0][0] \n",
"________________________________________________________________________________\n",
"res4c_branch2c (Conv2D) (None, 14, 14, 10 263168 activation_30[0][0] \n",
"________________________________________________________________________________\n",
"bn4c_branch2c (BatchNorma (None, 14, 14, 10 4096 res4c_branch2c[0][0] \n",
"________________________________________________________________________________\n",
"add_10 (Add) (None, 14, 14, 10 0 bn4c_branch2c[0][0] \n",
" activation_28[0][0] \n",
"________________________________________________________________________________\n",
"activation_31 (Activation (None, 14, 14, 10 0 add_10[0][0] \n",
"________________________________________________________________________________\n",
"res4d_branch2a (Conv2D) (None, 14, 14, 25 262400 activation_31[0][0] \n",
"________________________________________________________________________________\n",
"bn4d_branch2a (BatchNorma (None, 14, 14, 25 1024 res4d_branch2a[0][0] \n",
"________________________________________________________________________________\n",
"activation_32 (Activation (None, 14, 14, 25 0 bn4d_branch2a[0][0] \n",
"________________________________________________________________________________\n",
"res4d_branch2b (Conv2D) (None, 14, 14, 25 590080 activation_32[0][0] \n",
"________________________________________________________________________________\n",
"bn4d_branch2b (BatchNorma (None, 14, 14, 25 1024 res4d_branch2b[0][0] \n",
"________________________________________________________________________________\n",
"activation_33 (Activation (None, 14, 14, 25 0 bn4d_branch2b[0][0] \n",
"________________________________________________________________________________\n",
"res4d_branch2c (Conv2D) (None, 14, 14, 10 263168 activation_33[0][0] \n",
"________________________________________________________________________________\n",
"bn4d_branch2c (BatchNorma (None, 14, 14, 10 4096 res4d_branch2c[0][0] \n",
"________________________________________________________________________________\n",
"add_11 (Add) (None, 14, 14, 10 0 bn4d_branch2c[0][0] \n",
" activation_31[0][0] \n",
"________________________________________________________________________________\n",
"activation_34 (Activation (None, 14, 14, 10 0 add_11[0][0] \n",
"________________________________________________________________________________\n",
"res4e_branch2a (Conv2D) (None, 14, 14, 25 262400 activation_34[0][0] \n",
"________________________________________________________________________________\n",
"bn4e_branch2a (BatchNorma (None, 14, 14, 25 1024 res4e_branch2a[0][0] \n",
"________________________________________________________________________________\n",
"activation_35 (Activation (None, 14, 14, 25 0 bn4e_branch2a[0][0] \n",
"________________________________________________________________________________\n",
"res4e_branch2b (Conv2D) (None, 14, 14, 25 590080 activation_35[0][0] \n",
"________________________________________________________________________________\n",
"bn4e_branch2b (BatchNorma (None, 14, 14, 25 1024 res4e_branch2b[0][0] \n",
"________________________________________________________________________________\n",
"activation_36 (Activation (None, 14, 14, 25 0 bn4e_branch2b[0][0] \n",
"________________________________________________________________________________\n",
"res4e_branch2c (Conv2D) (None, 14, 14, 10 263168 activation_36[0][0] \n",
"________________________________________________________________________________\n",
"bn4e_branch2c (BatchNorma (None, 14, 14, 10 4096 res4e_branch2c[0][0] \n",
"________________________________________________________________________________\n",
"add_12 (Add) (None, 14, 14, 10 0 bn4e_branch2c[0][0] \n",
" activation_34[0][0] \n",
"________________________________________________________________________________\n",
"activation_37 (Activation (None, 14, 14, 10 0 add_12[0][0] \n",
"________________________________________________________________________________\n",
"res4f_branch2a (Conv2D) (None, 14, 14, 25 262400 activation_37[0][0] \n",
"________________________________________________________________________________\n",
"bn4f_branch2a (BatchNorma (None, 14, 14, 25 1024 res4f_branch2a[0][0] \n",
"________________________________________________________________________________\n",
"activation_38 (Activation (None, 14, 14, 25 0 bn4f_branch2a[0][0] \n",
"________________________________________________________________________________\n",
"res4f_branch2b (Conv2D) (None, 14, 14, 25 590080 activation_38[0][0] \n",
"________________________________________________________________________________\n",
"bn4f_branch2b (BatchNorma (None, 14, 14, 25 1024 res4f_branch2b[0][0] \n",
"________________________________________________________________________________\n",
"activation_39 (Activation (None, 14, 14, 25 0 bn4f_branch2b[0][0] \n",
"________________________________________________________________________________\n",
"res4f_branch2c (Conv2D) (None, 14, 14, 10 263168 activation_39[0][0] \n",
"________________________________________________________________________________\n",
"bn4f_branch2c (BatchNorma (None, 14, 14, 10 4096 res4f_branch2c[0][0] \n",
"________________________________________________________________________________\n",
"add_13 (Add) (None, 14, 14, 10 0 bn4f_branch2c[0][0] \n",
" activation_37[0][0] \n",
"________________________________________________________________________________\n",
"activation_40 (Activation (None, 14, 14, 10 0 add_13[0][0] \n",
"________________________________________________________________________________\n",
"res5a_branch2a (Conv2D) (None, 7, 7, 512) 524800 activation_40[0][0] \n",
"________________________________________________________________________________\n",
"bn5a_branch2a (BatchNorma (None, 7, 7, 512) 2048 res5a_branch2a[0][0] \n",
"________________________________________________________________________________\n",
"activation_41 (Activation (None, 7, 7, 512) 0 bn5a_branch2a[0][0] \n",
"________________________________________________________________________________\n",
"res5a_branch2b (Conv2D) (None, 7, 7, 512) 2359808 activation_41[0][0] \n",
"________________________________________________________________________________\n",
"bn5a_branch2b (BatchNorma (None, 7, 7, 512) 2048 res5a_branch2b[0][0] \n",
"________________________________________________________________________________\n",
"activation_42 (Activation (None, 7, 7, 512) 0 bn5a_branch2b[0][0] \n",
"________________________________________________________________________________\n",
"res5a_branch2c (Conv2D) (None, 7, 7, 2048 1050624 activation_42[0][0] \n",
"________________________________________________________________________________\n",
"res5a_branch1 (Conv2D) (None, 7, 7, 2048 2099200 activation_40[0][0] \n",
"________________________________________________________________________________\n",
"bn5a_branch2c (BatchNorma (None, 7, 7, 2048 8192 res5a_branch2c[0][0] \n",
"________________________________________________________________________________\n",
"bn5a_branch1 (BatchNormal (None, 7, 7, 2048 8192 res5a_branch1[0][0] \n",
"________________________________________________________________________________\n",
"add_14 (Add) (None, 7, 7, 2048 0 bn5a_branch2c[0][0] \n",
" bn5a_branch1[0][0] \n",
"________________________________________________________________________________\n",
"activation_43 (Activation (None, 7, 7, 2048 0 add_14[0][0] \n",
"________________________________________________________________________________\n",
"res5b_branch2a (Conv2D) (None, 7, 7, 512) 1049088 activation_43[0][0] \n",
"________________________________________________________________________________\n",
"bn5b_branch2a (BatchNorma (None, 7, 7, 512) 2048 res5b_branch2a[0][0] \n",
"________________________________________________________________________________\n",
"activation_44 (Activation (None, 7, 7, 512) 0 bn5b_branch2a[0][0] \n",
"________________________________________________________________________________\n",
"res5b_branch2b (Conv2D) (None, 7, 7, 512) 2359808 activation_44[0][0] \n",
"________________________________________________________________________________\n",
"bn5b_branch2b (BatchNorma (None, 7, 7, 512) 2048 res5b_branch2b[0][0] \n",
"________________________________________________________________________________\n",
"activation_45 (Activation (None, 7, 7, 512) 0 bn5b_branch2b[0][0] \n",
"________________________________________________________________________________\n",
"res5b_branch2c (Conv2D) (None, 7, 7, 2048 1050624 activation_45[0][0] \n",
"________________________________________________________________________________\n",
"bn5b_branch2c (BatchNorma (None, 7, 7, 2048 8192 res5b_branch2c[0][0] \n",
"________________________________________________________________________________\n",
"add_15 (Add) (None, 7, 7, 2048 0 bn5b_branch2c[0][0] \n",
" activation_43[0][0] \n",
"________________________________________________________________________________\n",
"activation_46 (Activation (None, 7, 7, 2048 0 add_15[0][0] \n",
"________________________________________________________________________________\n",
"res5c_branch2a (Conv2D) (None, 7, 7, 512) 1049088 activation_46[0][0] \n",
"________________________________________________________________________________\n",
"bn5c_branch2a (BatchNorma (None, 7, 7, 512) 2048 res5c_branch2a[0][0] \n",
"________________________________________________________________________________\n",
"activation_47 (Activation (None, 7, 7, 512) 0 bn5c_branch2a[0][0] \n",
"________________________________________________________________________________\n",
"res5c_branch2b (Conv2D) (None, 7, 7, 512) 2359808 activation_47[0][0] \n",
"________________________________________________________________________________\n",
"bn5c_branch2b (BatchNorma (None, 7, 7, 512) 2048 res5c_branch2b[0][0] \n",
"________________________________________________________________________________\n",
"activation_48 (Activation (None, 7, 7, 512) 0 bn5c_branch2b[0][0] \n",
"________________________________________________________________________________\n",
"res5c_branch2c (Conv2D) (None, 7, 7, 2048 1050624 activation_48[0][0] \n",
"________________________________________________________________________________\n",
"bn5c_branch2c (BatchNorma (None, 7, 7, 2048 8192 res5c_branch2c[0][0] \n",
"________________________________________________________________________________\n",
"add_16 (Add) (None, 7, 7, 2048 0 bn5c_branch2c[0][0] \n",
" activation_46[0][0] \n",
"________________________________________________________________________________\n",
"activation_49 (Activation (None, 7, 7, 2048 0 add_16[0][0] \n",
"________________________________________________________________________________\n",
"avg_pool (AveragePooling2 (None, 1, 1, 2048 0 activation_49[0][0] \n",
"________________________________________________________________________________\n",
"flatten_1 (Flatten) (None, 2048) 0 avg_pool[0][0] \n",
"________________________________________________________________________________\n",
"fc1000 (Dense) (None, 1000) 2049000 flatten_1[0][0] \n",
"================================================================================\n",
"Total params: 25,636,712\n",
"Trainable params: 25,583,592\n",
"Non-trainable params: 53,120\n",
"________________________________________________________________________________\n"
]
}
],
"source": [
"import tensorflow as tf\n",
"\n",
"tf.keras.backend.set_learning_phase(False)\n",
"model = tf.keras.applications.ResNet50()\n",
"model.summary(80)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## Step 0.5: Download tfcompile\n",
"XLA is still maturing and as of now we have to checkout the development release. System prerequisites are git, the build tool [Bazel](https://docs.bazel.build) and the [Protocol Buffers](https://developers.google.com/protocol-buffers) compiler. I'm also assuming we're running tf-nightly which can be installed via pip."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"slideshow": {
"slide_type": "skip"
}
},
"outputs": [],
"source": [
"%rm -rf /tmp/tensorflow"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/tmp\n",
"Cloning into 'tensorflow'...\n",
"remote: Counting objects: 10580, done.\u001b[K\n",
"remote: Compressing objects: 100% (8825/8825), done.\u001b[K\n",
"remote: Total 10580 (delta 3329), reused 3594 (delta 1486), pack-reused 0\u001b[K\n",
"Receiving objects: 100% (10580/10580), 21.65 MiB | 4.71 MiB/s, done.\n",
"Resolving deltas: 100% (3329/3329), done.\n",
"/tmp/tensorflow\n",
"WARNING: Running Bazel server needs to be killed, because the startup options are different.\n",
"You have bazel 0.8.1 installed.\n",
"Please specify the location of python. [Default is /home/carl/anaconda3/bin/python]: \n",
"\n",
"Found possible Python library paths:\n",
" /home/carl/anaconda3/lib/python3.6/site-packages\n",
"Please input the desired Python library path to use. Default is [/home/carl/anaconda3/lib/python3.6/site-packages]\n",
"Do you wish to build TensorFlow with jemalloc as malloc support? [Y/n]: jemalloc as malloc support will be enabled for TensorFlow.\n",
"\n",
"Do you wish to build TensorFlow with Google Cloud Platform support? [Y/n]: Google Cloud Platform support will be enabled for TensorFlow.\n",
"\n",
"Do you wish to build TensorFlow with Hadoop File System support? [Y/n]: Hadoop File System support will be enabled for TensorFlow.\n",
"\n",
"Do you wish to build TensorFlow with Amazon S3 File System support? [Y/n]: Amazon S3 File System support will be enabled for TensorFlow.\n",
"\n",
"Do you wish to build TensorFlow with XLA JIT support? [y/N]: No XLA JIT support will be enabled for TensorFlow.\n",
"\n",
"Do you wish to build TensorFlow with GDR support? [y/N]: No GDR support will be enabled for TensorFlow.\n",
"\n",
"Do you wish to build TensorFlow with VERBS support? [y/N]: No VERBS support will be enabled for TensorFlow.\n",
"\n",
"Do you wish to build TensorFlow with OpenCL SYCL support? [y/N]: No OpenCL SYCL support will be enabled for TensorFlow.\n",
"\n",
"Do you wish to build TensorFlow with CUDA support? [y/N]: No CUDA support will be enabled for TensorFlow.\n",
"\n",
"Do you wish to build TensorFlow with MPI support? [y/N]: No MPI support will be enabled for TensorFlow.\n",
"\n",
"Please specify optimization flags to use during compilation when bazel option \"--config=opt\" is specified [Default is -march=native]: \n",
"\n",
"Would you like to interactively configure ./WORKSPACE for Android builds? [y/N]: Not configuring the WORKSPACE for Android builds.\n",
"\n",
"Preconfigured Bazel build configs. You can use any of the below by adding \"--config=<>\" to your build command. See tools/bazel.rc for more details.\n",
"\t--config=mkl \t# Build with MKL support.\n",
"\t--config=monolithic \t# Config for mostly static monolithic build.\n",
"Configuration finished\n",
"yes: standard output: Broken pipe\n"
]
}
],
"source": [
"%cd /tmp\n",
"!git clone --depth=1 --single-branch https://github.com/tensorflow/tensorflow\n",
"%cd tensorflow\n",
"!yes \"\" | ./configure\n",
"!protoc tensorflow/compiler/tf2xla/tf2xla.proto --python_out=.\n",
"!cp tensorflow/compiler/tf2xla/tf2xla_pb2.py ."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## Step 1: Configure the subgraph to compile."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### List feeds and fetches\n",
"tfcompile needs static input shapes so we have to pick a batch size for our image classifier."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"import tf2xla_pb2\n",
"\n",
"config = tf2xla_pb2.Config()\n",
"\n",
"batch_size = 1\n",
"\n",
"for x in model.inputs:\n",
" x.set_shape([batch_size] + list(x.shape)[1:])\n",
" feed = config.feed.add()\n",
" feed.id.node_name = x.op.name\n",
" feed.shape.MergeFrom(x.shape.as_proto())\n",
"\n",
"for x in model.outputs:\n",
" fetch = config.fetch.add()\n",
" fetch.id.node_name = x.op.name\n",
"\n",
"with open('graph.config.pbtxt', 'w') as f:\n",
" f.write(str(config))"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"feed {\r\n",
" id {\r\n",
" node_name: \"input_1\"\r\n",
" }\r\n",
" shape {\r\n",
" dim {\r\n",
" size: 1\r\n",
" }\r\n",
" dim {\r\n",
" size: 224\r\n",
" }\r\n",
" dim {\r\n",
" size: 224\r\n",
" }\r\n",
" dim {\r\n",
" size: 3\r\n",
" }\r\n",
" }\r\n",
"}\r\n",
"fetch {\r\n",
" id {\r\n",
" node_name: \"fc1000/Softmax\"\r\n",
" }\r\n",
"}\r\n"
]
}
],
"source": [
"cat graph.config.pbtxt"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Freeze graph\n",
"The graph contains mutable nodes that have to be constants. It's possible to let tfcompile handle this for you (via [freeze_graph.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py)) by providing a weights checkpoint along with the graph definition, but as we already have everything loaded we'll make them into constants right away."
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Froze 320 variables.\n",
"Converted 320 variables to const ops.\n"
]
},
{
"data": {
"text/plain": [
"'./graph.pb'"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"session = tf.keras.backend.get_session()\n",
"output_node_names = [node.op.name for node in model.outputs]\n",
"graphdef = tf.graph_util.convert_variables_to_constants(session, session.graph_def, output_node_names)\n",
"tf.train.write_graph(graphdef, '.', 'graph.pb', as_text=False)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## Step 2: Use the tf_library build macro to compile the subgraph."
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Overwriting BUILD\n"
]
}
],
"source": [
"%%writefile BUILD\n",
"\n",
"load('@org_tensorflow//tensorflow/compiler/aot:tfcompile.bzl', 'tf_library')\n",
"\n",
"tf_library(\n",
" name = 'graph',\n",
" config = 'graph.config.pbtxt',\n",
" cpp_class = 'Graph',\n",
" graph = 'graph.pb',\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
".......\n",
"\u001b[32mLoading:\u001b[0m \n",
"\u001b[1A\u001b[K\u001b[32mLoading:\u001b[0m 0 packages loaded\n",
"\u001b[1A\u001b[K\u001b[35mWARNING: \u001b[0m/home/carl/.cache/bazel/_bazel_carl/e5cce820cc082410b4fcc604db349066/external/org_tensorflow/tensorflow/core/BUILD:1816:1: in includes attribute of cc_library rule @org_tensorflow//tensorflow/core:framework_headers_lib: '../../../../external/nsync/public' resolves to 'external/nsync/public' not below the relative path of its package 'external/org_tensorflow/tensorflow/core'. This will be an error in the future. Since this rule was created by the macro 'cc_header_only_library', the error might have been caused by the macro implementation in /home/carl/.cache/bazel/_bazel_carl/e5cce820cc082410b4fcc604db349066/external/org_tensorflow/tensorflow/tensorflow.bzl:1143:30\n",
"\u001b[32mAnalyzing:\u001b[0m target @org_tensorflow//:graph (68 packages loaded)\n",
"\u001b[1A\u001b[K\u001b[32mINFO: \u001b[0mAnalysed target @org_tensorflow//:graph (74 packages loaded).\n",
"\u001b[32mBuilding:\u001b[0m no action running\n",
"\u001b[1A\u001b[K\u001b[32mINFO: \u001b[0mFound 1 target...\n",
"\u001b[32mBuilding:\u001b[0m no action running\n",
"\u001b[1A\u001b[K\u001b[32m[0 / 6]\u001b[0m BazelWorkspaceStatusAction stable-status.txt\n",
"\u001b[1A\u001b[K\u001b[32mINFO: \u001b[0mFrom Executing genrule @org_tensorflow//tensorflow/core:version_info_gen [for host]:\n",
"\u001b[32m[1,674 / 3,309]\u001b[0m @org_tensorflow//tensorflow/core:version_info_gen; 0s local\n",
"\u001b[1A\u001b[Kfatal: No names found, cannot describe anything.\n",
"\u001b[32m[1,674 / 3,309]\u001b[0m @org_tensorflow//tensorflow/core:version_info_gen; 0s local\n",
"\u001b[1A\u001b[K\u001b[32mINFO: \u001b[0mFrom Executing genrule @org_tensorflow//:gen_graph:\n",
"\u001b[32m[3,332 / 3,336]\u001b[0m Executing genrule @org_tensorflow//:gen_graph; 47s local\n",
"\u001b[1A\u001b[K2018-01-11 15:27:20.408071: I external/org_tensorflow/tensorflow/core/platform/s3/aws_logging.cc:53] Initializing Curl library\n",
"2018-01-11 15:27:20.514752: I external/org_tensorflow/tensorflow/core/platform/cpu_feature_guard.cc:137] Your CPU supports instructions that this TensorFlow binary was not compiled to use: SSE4.1 SSE4.2 AVX AVX2 FMA\n",
"\u001b[32m[3,332 / 3,336]\u001b[0m Executing genrule @org_tensorflow//:gen_graph; 47s local\n",
"\u001b[1A\u001b[KTarget @org_tensorflow//:graph up-to-date:\n",
"\u001b[32m[3,336 / 3,336]\u001b[0m no action running\n",
"\u001b[1A\u001b[K bazel-bin/external/org_tensorflow/libgraph.a\n",
"\u001b[32m[3,336 / 3,336]\u001b[0m no action running\n",
"\u001b[1A\u001b[K bazel-bin/external/org_tensorflow/libgraph.pic.a\n",
"\u001b[32m[3,336 / 3,336]\u001b[0m no action running\n",
"\u001b[1A\u001b[K bazel-bin/external/org_tensorflow/libgraph.so\n",
"\u001b[32m[3,336 / 3,336]\u001b[0m no action running\n",
"\u001b[1A\u001b[K\u001b[32mINFO: \u001b[0mElapsed time: 57.837s, Critical Path: 50.33s\n",
"\u001b[32m[3,336 / 3,336]\u001b[0m no action running\n",
"\u001b[1A\u001b[K\u001b[32mINFO:\u001b[0m Build completed successfully, 3 total actions\n",
"\u001b[0m"
]
}
],
"source": [
"!bazel build --show_progress_rate_limit=600 @org_tensorflow//:graph"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"// Generated by tfcompile, the TensorFlow graph compiler. DO NOT EDIT!\r\n",
"//\r\n",
"// This header was generated via ahead-of-time compilation of a TensorFlow\r\n",
"// graph. An object file corresponding to this header was also generated.\r\n",
"// This header gives access to the functionality in that object file.\r\n",
"//\r\n",
"// clang-format off\r\n",
"\r\n",
"#ifndef TFCOMPILE_GENERATED_____graph_H_ // NOLINT(build/header_guard)\r\n",
"#define TFCOMPILE_GENERATED_____graph_H_ // NOLINT(build/header_guard)\r\n",
"\r\n",
"\r\n",
"#include \"tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h\"\r\n",
"#include \"tensorflow/core/platform/types.h\"\r\n",
"\r\n",
"namespace Eigen { struct ThreadPoolDevice; }\r\n",
"namespace xla { class ExecutableRunOptions; }\r\n",
"\r\n",
"// (Implementation detail) Entry point to the function in the object file.\r\n",
"extern \"C\" void ____graph(\r\n",
" void* result, const xla::ExecutableRunOptions* run_options,\r\n",
" const void** args, void** temps, tensorflow::int64* profile_counters);\r\n",
"\r\n",
"\r\n",
"// Graph represents a computation previously specified in a\r\n",
"// TensorFlow graph, now compiled into executable code. This extends the generic\r\n",
"// XlaCompiledCpuFunction class with statically type-safe arg and result\r\n",
"// methods. Usage example:\r\n",
"//\r\n",
"// Graph computation;\r\n",
"// // ...set args using computation.argN methods\r\n",
"// CHECK(computation.Run());\r\n",
"// // ...inspect results using computation.resultN methods\r\n",
"//\r\n",
"// The Run method invokes the actual computation, with inputs read from arg\r\n",
"// buffers, and outputs written to result buffers. Each Run call may also use\r\n",
"// a set of temporary buffers for the computation.\r\n",
"//\r\n",
"// By default each instance of this class manages its own arg, result and temp\r\n",
"// buffers. The AllocMode constructor parameter may be used to modify the\r\n",
"// buffer allocation strategy.\r\n",
"//\r\n",
"// Under the default allocation strategy, this class is thread-compatible:\r\n",
"// o Calls to non-const methods require exclusive access to the object.\r\n",
"// o Concurrent calls to const methods are OK, if those calls are made while it\r\n",
"// is guaranteed that no thread may call a non-const method.\r\n",
"//\r\n",
"// The logical function signature is:\r\n",
"// (arg0: f32[1,224,224,3]) -> (f32[1,1000])\r\n",
"//\r\n",
"// Memory stats:\r\n",
"// arg bytes total: 602112\r\n",
"// arg bytes aligned: 602112\r\n",
"// temp bytes total: 17815208\r\n",
"// temp bytes aligned: 17815232\r\n",
"class Graph : public tensorflow::XlaCompiledCpuFunction {\r\n",
" public:\r\n",
" // Number of input arguments for the compiled computation.\r\n",
" static constexpr size_t kNumArgs = 1;\r\n",
"\r\n",
" // Byte size of each argument buffer. There are kNumArgs entries.\r\n",
" static const intptr_t* ArgSizes() {\r\n",
" static constexpr intptr_t kArgSizes[kNumArgs] = {602112};\r\n",
" return kArgSizes;\r\n",
" }\r\n",
"\r\n",
" // Returns static data used to create an XlaCompiledCpuFunction.\r\n",
" static const tensorflow::XlaCompiledCpuFunction::StaticData& StaticData() {\r\n",
" static XlaCompiledCpuFunction::StaticData* kStaticData = [](){\r\n",
" XlaCompiledCpuFunction::StaticData* data =\r\n",
" new XlaCompiledCpuFunction::StaticData;\r\n",
" data->raw_function = ____graph;\r\n",
" data->arg_sizes = ArgSizes();\r\n",
" data->num_args = kNumArgs;\r\n",
" data->temp_sizes = TempSizes();\r\n",
" data->num_temps = kNumTemps;\r\n",
" data->result_index = kResultIndex;\r\n",
" data->arg_names = StaticArgNames();\r\n",
" data->result_names = StaticResultNames();\r\n",
" data->program_shape = StaticProgramShape();\r\n",
" return data;\r\n",
" }();\r\n",
" return *kStaticData;\r\n",
" }\r\n",
"\r\n",
" Graph(AllocMode alloc_mode = AllocMode::ARGS_RESULTS_PROFILES_AND_TEMPS)\r\n",
" : XlaCompiledCpuFunction(StaticData(), alloc_mode) {}\r\n",
"\r\n",
" Graph(const Graph&) = delete;\r\n",
" Graph& operator=(const Graph&) = delete;\r\n",
"\r\n",
" // Arg methods for managing input buffers. Buffers are in row-major order.\r\n",
" // There is a set of methods for each positional argument, with the following\r\n",
" // general form:\r\n",
" //\r\n",
" // void set_argN_data(void* data)\r\n",
" // Sets the buffer of type T for positional argument N. May be called in\r\n",
" // any AllocMode. Must be called before Run to have an affect. Must be\r\n",
" // called in AllocMode::RESULTS_PROFILES_AND_TEMPS_ONLY for each positional\r\n",
" // argument, to set the argument buffers.\r\n",
" //\r\n",
" // T* argN_data()\r\n",
" // Returns the buffer of type T for positional argument N.\r\n",
" //\r\n",
" // T& argN(...dim indices...)\r\n",
" // Returns a reference to the value of type T for positional argument N,\r\n",
" // with dim indices specifying which value. No bounds checking is performed\r\n",
" // on dim indices.\r\n",
"\r\n",
" void set_arg0_data(void* data) {\r\n",
" set_arg_data(0, data);\r\n",
" }\r\n",
" float* arg0_data() {\r\n",
" return static_cast<float*>(arg_data(0));\r\n",
" }\r\n",
" float& arg0(size_t dim0, size_t dim1, size_t dim2, size_t dim3) {\r\n",
" return (*static_cast<float(*)[1][224][224][3]>(\r\n",
" arg_data(0)))[dim0][dim1][dim2][dim3];\r\n",
" }\r\n",
" const float* arg0_data() const {\r\n",
" return static_cast<const float*>(arg_data(0));\r\n",
" }\r\n",
" const float& arg0(size_t dim0, size_t dim1, size_t dim2, size_t dim3) const {\r\n",
" return (*static_cast<const float(*)[1][224][224][3]>(\r\n",
" arg_data(0)))[dim0][dim1][dim2][dim3];\r\n",
" }\r\n",
"\r\n",
" // Result methods for managing output buffers. Buffers are in row-major order.\r\n",
" // Must only be called after a successful Run call. There is a set of methods\r\n",
" // for each positional result, with the following general form:\r\n",
" //\r\n",
" // T* resultN_data()\r\n",
" // Returns the buffer of type T for positional result N.\r\n",
" //\r\n",
" // T& resultN(...dim indices...)\r\n",
" // Returns a reference to the value of type T for positional result N,\r\n",
" // with dim indices specifying which value. No bounds checking is performed\r\n",
" // on dim indices.\r\n",
" //\r\n",
" // Unlike the arg methods, there is no set_resultN_data method. The result\r\n",
" // buffers are managed internally, and may change after each call to Run.\r\n",
"\r\n",
" float* result0_data() {\r\n",
" return static_cast<float*>(result_data(0));\r\n",
" }\r\n",
" float& result0(size_t dim0, size_t dim1) {\r\n",
" return (*static_cast<float(*)[1][1000]>(\r\n",
" result_data(0)))[dim0][dim1];\r\n",
" }\r\n",
" const float* result0_data() const {\r\n",
" return static_cast<const float*>(result_data(0));\r\n",
" }\r\n",
" const float& result0(size_t dim0, size_t dim1) const {\r\n",
" return (*static_cast<const float(*)[1][1000]>(\r\n",
" result_data(0)))[dim0][dim1];\r\n",
" }\r\n",
"\r\n",
" private:\r\n",
" // Number of result and temporary buffers for the compiled computation.\r\n",
" static constexpr size_t kNumTemps = 10;\r\n",
" // The 0-based index of the result tuple in the temporary buffers.\r\n",
" static constexpr size_t kResultIndex = 2;\r\n",
"\r\n",
" // Byte size of each result / temporary buffer. There are kNumTemps entries.\r\n",
" static const intptr_t* TempSizes() {\r\n",
" static constexpr intptr_t kTempSizes[kNumTemps] = {-1, 4000, 8, -1, -1, -1, -1, -1, -1, 17811200};\r\n",
" return kTempSizes;\r\n",
" }\r\n",
"\r\n",
" // Array of names of each positional argument, terminated by nullptr.\r\n",
" static const char** StaticArgNames() {\r\n",
" return nullptr;\r\n",
" }\r\n",
"\r\n",
" // Array of names of each positional result, terminated by nullptr.\r\n",
" static const char** StaticResultNames() {\r\n",
" return nullptr;\r\n",
" }\r\n",
"\r\n",
" // Shape of the args and results.\r\n",
" static const xla::ProgramShape* StaticProgramShape() {\r\n",
" return nullptr;\r\n",
" }\r\n",
"};\r\n",
"\r\n",
"\r\n",
"#endif // TFCOMPILE_GENERATED_____graph_H_\r\n",
"\r\n",
"// clang-format on\r\n"
]
}
],
"source": [
"cat bazel-genfiles/graph.h"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## Step 3: Write code to invoke the subgraph."
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Writing graph.cc\n"
]
}
],
"source": [
"%%writefile graph.cc\n",
"\n",
"#define EIGEN_USE_THREADS\n",
"#define EIGEN_USE_CUSTOM_THREAD_POOL\n",
"\n",
"#include \"graph.h\"\n",
"#include \"third_party/eigen3/unsupported/Eigen/CXX11/Tensor\"\n",
"\n",
"extern \"C\" int run(float *input, float *output, int input_size, int output_size) {\n",
" Eigen::ThreadPool tp(std::thread::hardware_concurrency());\n",
" Eigen::ThreadPoolDevice device(&tp, tp.NumThreads());\n",
" Graph graph;\n",
" graph.set_thread_pool(&device);\n",
"\n",
" std::copy(input, input + input_size, graph.arg0_data());\n",
" auto ok = graph.Run();\n",
" if (not ok) return -1;\n",
" std::copy(graph.result0_data(), graph.result0_data() + output_size, output);\n",
" return 0;\n",
"}"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## Step 4: Create the final binary.\n",
"Instead of calling `gcc` directly, and as Bazel is already required for building the tfcompile tool, we'll make a `cc_binary` rule. In fact, we could just have done one big BUILD file directly after having cloned the TensorFlow repo."
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Appending to BUILD\n"
]
}
],
"source": [
"%%writefile -a BUILD\n",
"\n",
"cc_binary(\n",
" name = \"libmodel.so\",\n",
" srcs = [\"graph.cc\"],\n",
" deps = [\":graph\", \"//third_party/eigen3\"],\n",
" linkopts = [\"-lpthread\"],\n",
" linkshared = 1,\n",
" copts = [\"-fPIC\"],\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[32mLoading:\u001b[0m \n",
"\u001b[1A\u001b[K\u001b[32mLoading:\u001b[0m 0 packages loaded\n",
"\u001b[1A\u001b[K\u001b[35mWARNING: \u001b[0m/home/carl/.cache/bazel/_bazel_carl/e5cce820cc082410b4fcc604db349066/external/org_tensorflow/tensorflow/core/BUILD:1816:1: in includes attribute of cc_library rule @org_tensorflow//tensorflow/core:framework_headers_lib: '../../../../external/nsync/public' resolves to 'external/nsync/public' not below the relative path of its package 'external/org_tensorflow/tensorflow/core'. This will be an error in the future. Since this rule was created by the macro 'cc_header_only_library', the error might have been caused by the macro implementation in /home/carl/.cache/bazel/_bazel_carl/e5cce820cc082410b4fcc604db349066/external/org_tensorflow/tensorflow/tensorflow.bzl:1143:30\n",
"\u001b[32mAnalyzing:\u001b[0m target @org_tensorflow//:libmodel.so (2 packages loaded)\n",
"\u001b[1A\u001b[K\u001b[32mINFO: \u001b[0mAnalysed target @org_tensorflow//:libmodel.so (2 packages loaded).\n",
"\u001b[32mBuilding:\u001b[0m no action running\n",
"\u001b[1A\u001b[K\u001b[32mINFO: \u001b[0mFound 1 target...\n",
"\u001b[32mBuilding:\u001b[0m no action running\n",
"\u001b[1A\u001b[K\u001b[32m[0 / 5]\u001b[0m BazelWorkspaceStatusAction stable-status.txt\n",
"\u001b[1A\u001b[KTarget @org_tensorflow//:libmodel.so up-to-date:\n",
"\u001b[32m[632 / 632]\u001b[0m no action running\n",
"\u001b[1A\u001b[K bazel-bin/external/org_tensorflow/libmodel.so\n",
"\u001b[32m[632 / 632]\u001b[0m no action running\n",
"\u001b[1A\u001b[K\u001b[32mINFO: \u001b[0mElapsed time: 1.852s, Critical Path: 0.56s\n",
"\u001b[32m[632 / 632]\u001b[0m no action running\n",
"\u001b[1A\u001b[K\u001b[32mINFO:\u001b[0m Build completed successfully, 1 total action\n",
"\u001b[0m"
]
}
],
"source": [
"!bazel build --show_progress_rate_limit=60 @org_tensorflow//:libmodel.so"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"libmodel = np.ctypeslib.load_library('libmodel', 'bazel-bin/external/org_tensorflow')\n",
"libmodel.run.argtypes = [\n",
" np.ctypeslib.ndpointer(np.float32, ndim=4, shape=(1, 224, 224, 3), flags=('c', 'a')),\n",
" np.ctypeslib.ndpointer(np.float32, ndim=2, shape=(1, 1000), flags=('c', 'a', 'w')),\n",
" np.ctypeslib.ctypes.c_int,\n",
" np.ctypeslib.ctypes.c_int]\n",
"\n",
"\n",
"def predict(x):\n",
" x = np.require(x, np.float32, ('c', 'a'))\n",
" y = np.require(np.zeros((1, 1000)), np.float32, ('c', 'a', 'w'))\n",
" libmodel.run(x, y, x.size, y.size)\n",
" return y"
]
},
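{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "skip"
}
},
"source": [
"As a quick sanity check (a sketch; the random tensor is just a placeholder, not a meaningful image), the compiled library can be exercised without an image file:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"# Push random data through the shared library and check the output shape.\n",
"x = np.random.rand(1, 224, 224, 3).astype(np.float32)\n",
"y = predict(x)\n",
"assert y.shape == (1, 1000)\n",
"```"
]
},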
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"text/plain": [
"[('n02110806', 'basenji', 0.60816735),\n",
" ('n02441942', 'weasel', 0.10849755),\n",
" ('n02091244', 'Ibizan_hound', 0.081580825),\n",
" ('n02124075', 'Egyptian_cat', 0.044705715),\n",
" ('n02123597', 'Siamese_cat', 0.025189402)]"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from keras.preprocessing import image\n",
"from keras.applications.imagenet_utils import preprocess_input, decode_predictions\n",
"\n",
"image_path = input()\n",
"\n",
"x = image.img_to_array(image.load_img(image_path, target_size=(224, 224)))\n",
"x = x[None, ...]\n",
"x = preprocess_input(x)\n",
"y = predict(x)\n",
"decode_predictions(y)[0]"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"150 ms ± 199 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n",
"191 ms ± 604 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
]
}
],
"source": [
"%timeit model.predict(x)\n",
"%timeit predict(x)\n",
"np.testing.assert_allclose(model.predict(x), predict(x), atol=1e-5)"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2.96 s ± 456 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
]
}
],
"source": [
"%%timeit\n",
"model = tf.keras.applications.ResNet50()\n",
"model.predict(x)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"# References\n",
"- https://www.tensorflow.org/performance/xla/tfcompile\n",
"- https://developers.googleblog.com/2017/03/xla-tensorflow-compiled.html\n",
"- https://youtu.be/kAOanJczHA0\n",
"- https://youtu.be/2IOPpyyuLkc"
]
}
],
"metadata": {
"celltoolbar": "Slideshow",
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
@bjs3118 commented Jul 26, 2022

Do you know if there is any way to build the model into a static library?

@powderluv (Author) commented

I would use IREE or SHARK these days instead of XLA for CPU and GPU to create static executables.