@hsahovic
Created December 8, 2019 09:37
This snippet was used to recover a regular, flat Keras model from a saved model that contained a submodel (i.e., one of its layers was itself a model), so that model optimization (quantization, pruning) could be applied a posteriori. It can be extended to handle more types of layers.
from tensorflow.keras.models import Sequential
import tensorflow.keras.layers as keras_layers

# `model` is the loaded model to flatten. Here, model.layers[1] is a nested
# model whose own layers are spliced in between the first and remaining layers.
clone = Sequential()
layers = [model.layers[0]] + model.layers[1].layers + model.layers[2:]

# Layer types that can be rebuilt from their config and copied weights.
# Extend this tuple to support more layer types.
supported = (
    keras_layers.Conv2D,
    keras_layers.MaxPooling2D,
    keras_layers.Dense,
    keras_layers.Flatten,
)

for layer in layers:
    if isinstance(layer, keras_layers.InputLayer):
        # Input shape is specific to the original model; adjust as needed.
        clone.add(keras_layers.Input(shape=(192, 192, 3)))
    elif isinstance(layer, supported):
        # Recreate the layer from its config (same class, same hyperparameters),
        # then copy its weights; weightless layers such as Flatten get [].
        clone.add(type(layer).from_config(layer.get_config()))
        clone.layers[-1].set_weights(layer.get_weights())
    else:
        print("Unsupported layer type:", type(layer))
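The `layers = ...` line unwraps exactly one nested model at a fixed position (index 1). A minimal sketch of a recursive alternative follows, assuming that only nested models expose a `layers` list attribute (as Keras `Model` objects do) while ordinary layers do not. `flatten_layers` and its `is_input` parameter are hypothetical names for illustration, not part of Keras. Input layers found inside nested models are optionally dropped, since the flattened `Sequential` only needs the outer input.

```python
def flatten_layers(layers, is_input=lambda layer: False, depth=0):
    """Recursively expand nested models into one flat list of layers.

    Assumption: any element with a list-valued `layers` attribute is treated
    as a nested model and expanded in place; everything else is a leaf layer.
    Leaf layers for which `is_input` returns True are skipped when they occur
    inside a nested model (depth > 0), but kept at the top level.
    """
    flat = []
    for layer in layers:
        sub_layers = getattr(layer, "layers", None)
        if isinstance(sub_layers, list):
            # Nested model: splice its layers in at this position.
            flat.extend(flatten_layers(sub_layers, is_input, depth + 1))
        elif depth > 0 and is_input(layer):
            # A nested model's own input layer is redundant once flattened.
            continue
        else:
            flat.append(layer)
    return flat
```

With Keras, it could be called as `flatten_layers(model.layers, is_input=lambda l: isinstance(l, keras_layers.InputLayer))`, and the result used in place of the hand-built `layers` list, so the position and depth of submodels no longer need to be known in advance.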