@martinsbruveris
Created November 15, 2020 22:03
How to split tensorflow models into two at a given layer.
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Model Surgery"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"from tensorflow.keras import Input, Model\n",
"from tensorflow.keras.models import clone_model, Sequential\n",
"from tensorflow.keras.layers import Dense\n",
"from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"def build_model():\n",
"    model = Sequential()\n",
"    model.add(Input((2,)))\n",
"    model.add(Dense(4, name=\"fc1\"))\n",
"    model.add(Dense(8, name=\"fc2\"))\n",
"    model.add(Dense(16, name=\"fc3\"))\n",
"    model.add(Dense(2, name=\"fc4\"))\n",
"    return model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Keras' functional API makes it easy to combine models into a new model. All we have to do is the following."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"model_a = build_model()\n",
"model_b = build_model()\n",
"\n",
"x = Input((2, ))\n",
"y = model_b(model_a(x))\n",
"model = Model(x, y)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And we can test that we get the right answer."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tf.Tensor([[-0.62025166 1.3668724 ]], shape=(1, 2), dtype=float32)\n",
"tf.Tensor([[-0.62025166 1.3668724 ]], shape=(1, 2), dtype=float32)\n"
]
}
],
"source": [
"x = tf.constant([[1., 2.]])\n",
"y1 = model(x)\n",
"y2 = model_b(model_a(x))\n",
"print(y1)\n",
"print(y2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"But what about the inverse problem? How easy is it to split a given model into two? For example, we would like to split\n",
"```\n",
"model = build_model()\n",
"```\n",
"into `model_a` and `model_b` such that `model_a` contains the first two layers and `model_b` the last two. How would we go about doing that?\n",
"\n",
"It turns out that there is no simple solution to this problem. Or at least I could not find one. Instead I found two approaches that work, but come with some caveats that might be a problem in applications."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Keras functional API"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The following will fail, because inputs to functional models must come from `tf.keras.Input`. We cannot use arbitrary tensors as inputs as we are attempting here with `cut_layer.output`. If executed, the following cell will throw an error when attempting to construct `model_b`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model = build_model()\n",
"\n",
"cut_layer = model.get_layer(name=\"fc2\")\n",
"model_a = Model(model.inputs, cut_layer.output)\n",
"model_b = Model(cut_layer.output, model.outputs)\n",
"\n",
"# WARNING:tensorflow:Functional inputs must come from `tf.keras.Input` \n",
"# (thus holding past layer metadata), they cannot be the output of a \n",
"# previous non-Input layer. Here, a tensor specified as input to \n",
"# \"functional_4\" was not an Input tensor, it was generated by layer fc2.\n",
"# Note that input tensors are instantiated via `tensor = tf.keras.Input(shape)`.\n",
"# The tensor that caused the issue was: fc2/BiasAdd_2:0"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can attempt to fix the problem by transforming `cut_layer.output` into a `tf.keras.Input`. That works to some extent. Note that by using the `tensor` argument of `Input`, we don't create a new tensor, but reuse the existing one."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"model = build_model()\n",
"\n",
"cut_layer = model.get_layer(name=\"fc2\")\n",
"model_a = Model(model.inputs, cut_layer.output)\n",
"model_b = Model(\n",
"    Input(tensor=cut_layer.output), model.outputs\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We see that we have indeed successfully cut `model` into two submodels, `model_a` and `model_b`."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tf.Tensor([[-0.36279273 -0.7065454 ]], shape=(1, 2), dtype=float32)\n",
"tf.Tensor([[-0.36279273 -0.7065454 ]], shape=(1, 2), dtype=float32)\n"
]
}
],
"source": [
"x = tf.constant([[1., 2.]])\n",
"y1 = model(x)\n",
"y2 = model_b(model_a(x))\n",
"print(y1)\n",
"print(y2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"What is the problem with this approach? Converting the tensor `cut_layer.output` to an `Input` internally turns it into a placeholder (yes, they still exist in TF2), which causes problems if we want to cut the same model again. If executed, the following will result in an error."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Note that we are re-using the model from above\n",
"cut_layer = model.get_layer(name=\"fc3\")\n",
"model_c = Model(model.inputs, cut_layer.output)\n",
"\n",
"# ValueError: Graph disconnected: cannot obtain value for tensor \n",
"# Tensor(\"fc2/BiasAdd_3:0\", shape=(None, 8), dtype=float32) at layer \"fc3\". \n",
"# The following previous layers were accessed without issue: []"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Cloning models"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"One way to overcome the problem described above is to clone the model first and perform surgery on the clone. This strategy will work, provided that the model can be cloned using `tf.keras.models.clone_model`. Cloning might fail for subclassed models as well as for models with custom layers that have not implemented the `get_config` and `from_config` methods.\n",
"\n",
"But for models implemented with the functional API that use only standard Keras layers, this is a viable strategy."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"def cut_model(model, cut_layer_name):\n",
"    clone = clone_model(model)\n",
"    clone.set_weights(model.get_weights())\n",
"    cut_layer = clone.get_layer(name=cut_layer_name)\n",
"    model_a = Model(clone.inputs, cut_layer.output)\n",
"    model_b = Model(\n",
"        Input(tensor=cut_layer.output), clone.outputs\n",
"    )\n",
"    return model_a, model_b"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"model = build_model()"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[ 0.29277328 -1.2625693 ]] [[ 0.29277328 -1.2625693 ]]\n"
]
}
],
"source": [
"model_a, model_b = cut_model(model, \"fc2\")\n",
"\n",
"x = tf.constant([[1., 2.]])\n",
"y1 = model(x)\n",
"y2 = model_b(model_a(x))\n",
"print(y1.numpy(), y2.numpy())"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[ 0.29277328 -1.2625693 ]] [[ 0.29277328 -1.2625693 ]]\n"
]
}
],
"source": [
"model_c, model_d = cut_model(model, \"fc3\")\n",
"\n",
"x = tf.constant([[1., 2.]])\n",
"y1 = model(x)\n",
"y2 = model_d(model_c(x))\n",
"print(y1.numpy(), y2.numpy())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Converting to graphs"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If we are willing to leave the world of Keras `Model`s and accept a decomposition of the model into two graphs that are not themselves `Model` objects, we can approach the problem as follows."
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"model = build_model()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"First we use `tf.function` to convert `model` to a callable function and then we let TF construct its graph via `get_concrete_function`."
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"model_function = tf.function(lambda x: model(x))\n",
"model_function = model_function.get_concrete_function(\n",
"    tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype)\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before we can use the graph, we need to convert variables to constants. Keras leaves variables in the model as resources that need to be passed in using the `feed_dict` at inference time."
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"model_function = convert_variables_to_constants_v2(model_function)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we can extract the graph."
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"graph = model_function.graph\n",
"isinstance(graph, tf.Graph)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The difficulty is correlating the Keras layer output with the corresponding operation in the graph, because the names are subtly different. For example, the output tensor of the Keras layer `fc2` is `fc2/BiasAdd_6:0`."
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<tf.Tensor 'fc2/BiasAdd_6:0' shape=(None, 8) dtype=float32>"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.get_layer(\"fc2\").output"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The corresponding operation in the graph, however, is `sequential_4/fc2/BiasAdd`. Almost the same, but not quite."
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['sequential_4/fc2/MatMul/ReadVariableOp/resource',\n",
" 'sequential_4/fc2/MatMul/ReadVariableOp',\n",
" 'sequential_4/fc2/MatMul',\n",
" 'sequential_4/fc2/BiasAdd/ReadVariableOp/resource',\n",
" 'sequential_4/fc2/BiasAdd/ReadVariableOp',\n",
" 'sequential_4/fc2/BiasAdd']"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"[op.name for op in graph.get_operations() if \"fc2\" in op.name]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This is a slightly hacky way to find the names of the corresponding output tensors in the graph. No guarantees that it will generalise well to all graphs."
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"def get_graph_outputs(graph, model, layer_name):\n",
"    output = model.get_layer(layer_name).output.name\n",
"    # Remove the \":0\" part\n",
"    output = output.split(\":\")[0]\n",
"\n",
"    # The assumption is that the layer name has an optional \"_X\"\n",
"    # compared to the graph op\n",
"    op_found = False\n",
"    for op in graph.get_operations():\n",
"        prefix = model.name + \"/\"\n",
"        if not op.name.startswith(prefix):\n",
"            continue\n",
"\n",
"        # Remove the prefix, e.g., \"sequential_1/\"\n",
"        op_without_prefix = op.name[len(prefix):]\n",
"        if output.startswith(op_without_prefix):\n",
"            op_found = True\n",
"            break\n",
"\n",
"    if not op_found:\n",
"        raise ValueError(\"Op not found in graph.\")\n",
"\n",
"    # Now we have the op and its outputs\n",
"    outputs = [out.name for out in op.outputs]\n",
"    return outputs"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"outputs = get_graph_outputs(graph, model, \"fc1\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we use `tf.import_graph_def` to essentially create a copy of the graph with the desired inputs and outputs."
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"@tf.function\n",
"def model_a(inputs):\n",
"    inputs = tf.convert_to_tensor(inputs)\n",
"    y = tf.import_graph_def(\n",
"        graph.as_graph_def(),\n",
"        input_map={graph.inputs[0].name: inputs},\n",
"        return_elements=outputs,\n",
"    )\n",
"    return y[0]"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"@tf.function\n",
"def model_b(inputs):\n",
"    inputs = tf.convert_to_tensor(inputs)\n",
"    y = tf.import_graph_def(\n",
"        graph.as_graph_def(),\n",
"        input_map={outputs[0]: inputs},\n",
"        return_elements=[op.name for op in graph.outputs],\n",
"    )\n",
"    return y[0]"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[-1.0905744 1.1008097]] [[-1.0905744 1.1008097]]\n"
]
}
],
"source": [
"x = tf.constant([[1., 2.]])\n",
"y1 = model(x)\n",
"y2 = model_b(model_a(x))\n",
"print(y1.numpy(), y2.numpy())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Finally we put all this into one function."
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"def cut_model(model, cut_layer_name):\n",
"    model_function = tf.function(lambda x: model(x))\n",
"    model_function = model_function.get_concrete_function(\n",
"        tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype)\n",
"    )\n",
"    model_function = convert_variables_to_constants_v2(model_function)\n",
"    graph = model_function.graph\n",
"\n",
"    outputs = get_graph_outputs(graph, model, cut_layer_name)\n",
"\n",
"    @tf.function\n",
"    def _model_a(inputs):\n",
"        inputs = tf.convert_to_tensor(inputs)\n",
"        y = tf.import_graph_def(\n",
"            graph.as_graph_def(),\n",
"            input_map={graph.inputs[0].name: inputs},\n",
"            return_elements=outputs,\n",
"        )\n",
"        return y[0]\n",
"\n",
"    @tf.function\n",
"    def _model_b(inputs):\n",
"        inputs = tf.convert_to_tensor(inputs)\n",
"        y = tf.import_graph_def(\n",
"            graph.as_graph_def(),\n",
"            input_map={outputs[0]: inputs},\n",
"            return_elements=[op.name for op in graph.outputs],\n",
"        )\n",
"        return y[0]\n",
"\n",
"    return _model_a, _model_b"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[0.5226863 0.5934272]] [[0.5226863 0.5934272]]\n"
]
}
],
"source": [
"model = build_model()\n",
"model_a, model_b = cut_model(model, \"fc2\")\n",
"\n",
"x = tf.constant([[1., 2.]])\n",
"y1 = model(x)\n",
"y2 = model_b(model_a(x))\n",
"print(y1.numpy(), y2.numpy())"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
}
},
"nbformat": 4,
"nbformat_minor": 4
}