@DSamuylov
Last active November 11, 2018 19:16
Failure when saving the weights of a model that contains a sub-model sharing weights with another model.
import tensorflow.keras as keras
# import keras
keras.backend.clear_session()
# Model that we are going to share.
x = keras.layers.Input(shape=[None, None, 3])
y = keras.layers.Conv2D(filters=4, kernel_size=(3, 3))(x)
model = keras.models.Model(inputs=x, outputs=y)
print("# Save `model`:")
model.save_weights(filepath="test.h5")
# IT WORKS!
# We import this model as part of a more complex model. Moreover, we share its weights across multiple inputs:
x1 = keras.layers.Input(shape=[None, None, 3])
y1 = model(x1)
model_to_be_shared = keras.models.Model(inputs=x1, outputs=y1)
x2 = keras.layers.Input(shape=[None, None, 3])
y2 = model_to_be_shared(x2)
model_with_2_inputs = keras.models.Model(inputs=[x1, x2], outputs=keras.layers.Add()([y1, y2]))
print("# Save `model_with_2_inputs`:")
model_with_2_inputs.save_weights(filepath="test.h5")
# IT WORKS!
# Finally, we wrap `model_with_2_inputs` as part of an even more complex model:
x1 = keras.layers.Input(shape=[None, None, 3])
x2 = keras.layers.Input(shape=[None, None, 3])
y = model_with_2_inputs([x1, x2])
model_final = keras.models.Model(inputs=[x1, x2], outputs=y)
print("# Save `model_final`:")
model_final.save_weights(filepath="test.h5")
# IT FAILS!
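The gist does not say why the third save fails while the first two succeed. One plausible explanation, offered as an assumption rather than a confirmed diagnosis: the HDF5 format written by `save_weights` creates one group per top-level layer of the model being saved and one dataset per weight name inside that group. In `model_with_2_inputs`, the shared convolution is reached through two top-level layers (`model` and `model_to_be_shared`), so its weights go into two separate groups. In `model_final`, both paths sit under the single top-level layer `model_with_2_inputs`, so the same weight names would be written twice into one group. The toy sketch below models only that layout with plain dictionaries; `save_weights_grouped`, the group names, and the weight names are hypothetical illustrations, not Keras's own saving code.

```python
from typing import Dict, List, Tuple

# Hypothetical stand-in for an HDF5 weights file: group name -> {dataset name -> value}.
Store = Dict[str, Dict[str, object]]
# (layer name, [(weight name, value), ...]) for one top-level layer.
Layer = Tuple[str, List[Tuple[str, object]]]


def save_weights_grouped(layers: List[Layer]) -> Store:
    """Write one group per top-level layer and one dataset per weight name in it.

    Like h5py, refuse to create a group or dataset under a name that already exists.
    """
    store: Store = {}
    for layer_name, weights in layers:
        if layer_name in store:
            raise ValueError(f"group {layer_name!r} already exists")
        group: Dict[str, object] = {}
        store[layer_name] = group
        for weight_name, value in weights:
            if weight_name in group:
                raise ValueError(
                    f"dataset {weight_name!r} already exists in group {layer_name!r}"
                )
            group[weight_name] = value
    return store


# Illustrative names for the shared Conv2D's variables.
conv = [("conv2d/kernel:0", "kernel"), ("conv2d/bias:0", "bias")]

# Like saving model_with_2_inputs: the conv is reached through two top-level
# layers, so each copy of its weights lands in a different group.
two_groups = save_weights_grouped([("model", conv), ("model_1", conv)])
print(sorted(two_groups))  # ['model', 'model_1']

# Like saving model_final: both paths sit under a single top-level layer,
# so the same weight names are written twice into one group.
try:
    save_weights_grouped([("model_2", conv + conv)])
    final_failed = False
except ValueError as err:
    final_failed = True
    print(err)
```

If this model of the failure is right, the problem is not weight sharing as such but name collisions once several sharing paths are flattened under one top-level layer.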
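A quick way to probe the failing step is to check whether the weights of the wrapped model list the same variable name more than once: the HDF5 writer stores each weight as a dataset keyed by its name, so a repeated name within one layer's weights would presumably collide. The helper below, `duplicate_weight_names`, is a hypothetical diagnostic, not part of Keras. It relies only on each weight having a `.name` attribute, which TensorFlow variables provide; the demo uses stand-in objects with illustrative names so it runs without TensorFlow.

```python
from collections import Counter
from types import SimpleNamespace
from typing import Iterable, List


def duplicate_weight_names(weights: Iterable) -> List[str]:
    """Return, in first-seen order, every weight name that occurs more than once.

    `weights` may be any iterable of objects with a `.name` attribute,
    e.g. the list returned by a Keras model's `.weights` property.
    """
    names = [w.name for w in weights]
    counts = Counter(names)
    seen = set()
    dups = []
    for name in names:
        if counts[name] > 1 and name not in seen:
            dups.append(name)
            seen.add(name)
    return dups


# Stand-ins for the shared Conv2D's variables, listed twice as they might be
# when the same sub-model is reachable through two paths.
kernel = SimpleNamespace(name="conv2d/kernel:0")
bias = SimpleNamespace(name="conv2d/bias:0")
print(duplicate_weight_names([kernel, bias, kernel, bias]))  # ['conv2d/kernel:0', 'conv2d/bias:0']
```

With TensorFlow installed, calling `duplicate_weight_names(model_with_2_inputs.weights)` after running the script above would show whether the shared kernel and bias are listed twice. The result may depend on the Keras version, since some releases may deduplicate shared weights in `.weights`.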