Last active
November 1, 2020 15:08
-
-
Save sborquez/b14bc5aa84c96692975a0f21ea1e57eb to your computer and use it in GitHub Desktop.
Split a trained tf.keras or Keras model into two models so each part can be used for predictions separately.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from tensorflow.keras.models import Model | |
from tensorflow.keras.layers import Input | |
def split_model(model, split_layer_name=None, split_layer_index=None):
    """
    Split a trained model into two models for separate predictions.

    The first model (encoder) maps the source model's input to the output
    of the split layer. The second model (regressor) takes a tensor with
    that output's shape as its input and applies every layer that comes
    after the split layer.

    If `split_layer_name` and `split_layer_index` are both provided,
    `split_layer_index` will take precedence. Indices are based on order of
    horizontal graph traversal (bottom-up), i.e. positions in `model.layers`.
    Negative indices count from the end of `model.layers`.

    The regressor is rebuilt by applying ``model.layers[index + 1:]`` one
    after another, so this assumes the layers after the split form a single
    linear chain (as in a `Sequential` model); branching or multi-input
    topologies after the split point are not supported.

    Parameters
    ----------
    model : `tensorflow.keras.Model`
        Source trained model.
    split_layer_name : `str` or `None`
        Name of the layer at which to split the model (it becomes the last
        layer of the encoder).
    split_layer_index : `int` or `None`
        Index of the layer at which to split the model (it becomes the last
        layer of the encoder).

    Returns
    -------
    `tuple` of `tensorflow.keras.Model`
        Encoder and Regressor models with the source model's weights.

    Raises
    ------
    ValueError
        If neither `split_layer_name` nor `split_layer_index` is provided.
    """
    if split_layer_index is None:
        # Compare against None (not truthiness) so that index 0 is honoured
        # instead of silently falling back to a name lookup.
        if split_layer_name is None:
            raise ValueError(
                "Either `split_layer_name` or `split_layer_index` must be provided."
            )
        split_layer_index = model.layers.index(model.get_layer(split_layer_name))
    elif split_layer_index < 0:
        # Normalize so the slice below starts right after the split layer;
        # otherwise e.g. -1 would make the regressor re-apply every layer.
        split_layer_index += len(model.layers)

    # First model: source input -> output of the split layer.
    encoder = Model(
        model.input,
        model.get_layer(index=split_layer_index).output
    )
    latent_variables_shape = encoder.output.shape[1:]  # drop the batch dimension

    # Second model: fresh input with the latent shape -> remaining layers.
    x = regressor_input = Input(shape=latent_variables_shape)
    for layer in model.layers[split_layer_index + 1:]:
        x = layer(x)
    regressor = Model(regressor_input, x)

    # Copy weights. The loop above calls the original layer objects, so they
    # are reused (Keras shared layers); the explicit copy is kept as a
    # safeguard. `layers[0]` is the regressor's InputLayer and has no weights.
    for layer in regressor.layers[1:]:
        layer.set_weights(
            model.get_layer(name=layer.name).get_weights()
        )
    return encoder, regressor
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment