Skip to content

Instantly share code, notes, and snippets.

@phisad
Created March 6, 2019 22:05
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save phisad/4d8b7fd86cb44f66132648d91052b1d8 to your computer and use it in GitHub Desktop.
Save phisad/4d8b7fd86cb44f66132648d91052b1d8 to your computer and use it in GitHub Desktop.
Keras disconnected graphs example when using multiple models
def test_connected_models(self):
    """Build one model from the outputs of two independent sub-models.

    model1 maps a (100,) input to 1 unit and model2 maps a (200,) input to
    2 units. Their output tensors are concatenated directly, so the combined
    model3 contains no intermediate Input layers. Its graph is therefore
    connected from [model1.input, model2.input] to the concatenation.
    """
    input1 = Input(shape=(100,))
    dense1 = Dense(1)(input1)
    model1 = Model(input1, dense1)
    input2 = Input(shape=(200,))
    dense2 = Dense(2)(input2)
    model2 = Model(input2, dense2)
    # This works because there are no intermediate Inputs between the
    # sub-models and the concatenation. For a complex graph, the best
    # solution is not to build intermediate Models at all. Instead, use the
    # Functional API throughout and produce only one model at the end.
    dense3 = Concatenate()([model1.output, model2.output])
    model3 = Model(inputs=[model1.input, model2.input], outputs=dense3)
    # summary() prints the table itself and returns None. Wrapping it in
    # print() would also emit a stray "None".
    model3.summary()
def test_disconnected_graph(self):
    """Demonstrate the "graph disconnected" failure when chaining sub-models.

    model3 is built on its own Input layers (input11, input12). Its output
    tensor is therefore reachable only from those inputs, not from
    model1.input / model2.input. As a result, the final Model(...) call is
    expected to fail, because Keras cannot trace a path from the requested
    inputs to the requested output. The error is intentionally not caught
    here; this method exists to show the failure.
    """
    input1 = Input(shape=(100,))
    dense1 = Dense(1)(input1)
    model1 = Model(input1, dense1)
    input2 = Input(shape=(200,))
    dense2 = Dense(2)(input2)
    model2 = Model(input2, dense2)
    # Widths match the outputs of model1 (1 unit) and model2 (2 units).
    input11 = Input(shape=(1,))
    input12 = Input(shape=(2,))
    dense3 = Concatenate()([input11, input12])
    model3 = Model(inputs=[input11, input12], outputs=dense3)
    # summary() prints the table itself and returns None, so it is not
    # wrapped in print().
    model3.summary()
    # model3.layers[-1] is the Concatenate layer, so its output is the same
    # tensor as model3.output. That tensor is computed from input11 and
    # input12, which are disconnected from model1.input and model2.input, so
    # building a Model between them fails. To connect the graphs instead,
    # call model3 like a layer on the upstream outputs, for example
    # model3([model1.output, model2.output]).
    concat_output = model3.layers[-1].output
    model123 = Model(inputs=[model1.input, model2.input], outputs=concat_output)
    model123.summary()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment