Created
December 8, 2019 09:37
-
-
Save hsahovic/365e6955f950e06821e8b9e4869bc78c to your computer and use it in GitHub Desktop.
This snippet was used to recover a proper Keras model from a saved model that contained a submodel (i.e., one of its layers was itself another model), in order to apply model optimization a posteriori (quantization, pruning). It can be extended to handle more types of layers.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from tensorflow.keras.models import Sequential
import tensorflow.keras.layers as keras_layers

# Rebuild a flat Sequential copy (`clone`) of `model`, whose second layer is
# itself a model. `model` is not defined in this snippet: it must already
# exist (e.g. have been loaded) before this code runs.

# Shape of a single input sample (no batch dimension) given to the clone.
INPUT_SHAPE = (192, 192, 3)

# Layer classes that are re-created from their config, keyed by class name.
# Matching on the exact class name (rather than on a suffix of the type's
# string form) keeps e.g. GlobalMaxPooling2D from being mistaken for
# MaxPooling2D. Add more classes here to support more layer types.
CLONEABLE_LAYERS = {
    layer_class.__name__: layer_class
    for layer_class in (
        keras_layers.Conv2D,
        keras_layers.MaxPooling2D,
        keras_layers.Dense,
        keras_layers.Flatten,
    )
}

clone = Sequential()

# Here, model.layers[1] is a model: splice its layers in where it stood.
layers = [model.layers[0]] + model.layers[1].layers + model.layers[2:]

for layer in layers:
    layer_type_name = type(layer).__name__
    if layer_type_name == "InputLayer":
        clone.add(keras_layers.Input(shape=INPUT_SHAPE))
    elif layer_type_name in CLONEABLE_LAYERS:
        # Re-create the layer from its config, then copy the weights over
        # once it has been added to the clone.
        new_layer = CLONEABLE_LAYERS[layer_type_name].from_config(
            layer.get_config()
        )
        clone.add(new_layer)
        new_layer.set_weights(layer.get_weights())
    else:
        # Unsupported layer type: it is NOT added to the clone, only reported.
        print(str(type(layer)))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment