Skip to content

Instantly share code, notes, and snippets.

@sborquez
Last active November 1, 2020 15:08
Show Gist options
  • Save sborquez/b14bc5aa84c96692975a0f21ea1e57eb to your computer and use it in GitHub Desktop.
Save sborquez/b14bc5aa84c96692975a0f21ea1e57eb to your computer and use it in GitHub Desktop.
Split a trained tf.keras or Keras model into two models for separate predictions.
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input
def split_model(model, split_layer_name=None, split_layer_index=None):
    """
    Split a trained model into two models for separate predictions.

    The first returned model (encoder) maps ``model.input`` to the output of
    the split layer. The second returned model (regressor) takes a tensor
    shaped like that output and applies every layer that comes after the
    split layer, in the order they appear in ``model.layers``.

    If `split_layer_name` and `split_layer_index` are both provided,
    `split_layer_index` takes precedence.
    Indices are based on order of horizontal graph traversal (bottom-up),
    i.e. positions in ``model.layers``; negative indices count from the end.

    Parameters
    ----------
    model : `tensorflow.keras.Model`
        Source trained model. The layers after the split layer are re-applied
        one after another, so this assumes the part of the graph after the
        split is a single linear chain (e.g. a Sequential-style model).
    split_layer_name : `str` or `None`
        Name of the layer at which to split the model.
    split_layer_index : `int` or `None`
        Index in ``model.layers`` of the layer at which to split the model.
        The split layer is the last layer of the encoder.

    Returns
    -------
    `tuple` of `tensorflow.keras.Model`
        ``(encoder, regressor)``. Both are built from the source model's own
        layer objects, so they share (not copy) the source model's weights.

    Raises
    ------
    ValueError
        If neither `split_layer_name` nor `split_layer_index` is provided.
    """
    # Resolve the split index. Use `is None` (not truthiness) so that a
    # legitimate index of 0 is honoured instead of silently ignored.
    if split_layer_index is None:
        if split_layer_name is None:
            raise ValueError(
                "Provide either `split_layer_name` or `split_layer_index`."
            )
        split_layer_index = model.layers.index(model.get_layer(split_layer_name))
    elif split_layer_index < 0:
        # Normalize so the slice below starts right after the split layer;
        # without this, e.g. -1 would slice from 0 and re-apply every layer.
        split_layer_index += len(model.layers)

    # First model: source input -> output of the split layer.
    encoder = Model(
        model.input,
        model.get_layer(index=split_layer_index).output,
    )
    # Input(shape=...) excludes the batch dimension, so drop the first axis.
    latent_variables_shape = encoder.output.shape[1:]

    # Second model: chain the remaining layers on top of a fresh input.
    x = regressor_input = Input(shape=latent_variables_shape)
    for layer in model.layers[split_layer_index + 1:]:
        x = layer(x)
    regressor = Model(regressor_input, x)

    # No explicit weight copy is needed: the regressor reuses the source
    # model's layer instances, so it already holds the trained weights.
    return encoder, regressor
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment