@fchollet
Last active May 23, 2019 11:14
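'''Functional Keras is a more functional replacement for the Graph API.
'''
# Keras 1.x API: Keras 2 replaced merge(..., mode='concat') with concatenate(...)
# and Model(input=..., output=...) with Model(inputs=..., outputs=...).
import numpy as np
import tensorflow as tf
from keras.layers import Input, LSTM, RepeatVector, Lambda, merge
from keras.models import Model
from keras.optimizers import Adam

# Dummy data: 100 samples of 10 timesteps x 32 features per input,
# and a target matching the decoder output shape (10, 32).
x1 = np.random.random((100, 10, 32))
x2 = np.random.random((100, 10, 32))
y = np.random.random((100, 10, 32))

###################
# 2 LSTM branches #
###################
a = Input(shape=(10, 32))  # output is a TF/Theano placeholder, augmented with Keras attributes
b = Input(shape=(10, 32))
encoded_a = LSTM(32)(a)  # output is a TF/Theano tensor
encoded_b = LSTM(32)(b)
merged = merge([encoded_a, encoded_b], mode='concat')
decoded = RepeatVector(10)(merged)
decoded = LSTM(32, return_sequences=True)(decoded)
# this is a fully-featured Keras model, with all the goodies that come with those.
# this is made possible by Keras topology information stored in the tensors.
model = Model(input=[a, b], output=[decoded])
model.compile(optimizer=Adam(), loss='mse')
model.fit([x1, x2], y)
################
# Shared layer #
################
shared_lstm = LSTM(32)  # a single layer instance, hence a single set of weights
a = Input(shape=(10, 32))
b = Input(shape=(10, 32))
encoded_a = shared_lstm(a)  # both branches reuse the same weights
encoded_b = shared_lstm(b)
merged = merge([encoded_a, encoded_b], mode='concat')
decoded = RepeatVector(10)(merged)
decoded = LSTM(32, return_sequences=True)(decoded)
##############################
# Insertion of arbitrary ops #
##############################
# NOTE: cannot do a = tf.sigmoid(a): although 'a' is a valid tf tensor, it is
# 'augmented' with data that allows Keras to keep track of previous operations
# (which is what makes it possible to build and train a model). tf.sigmoid would
# return a plain tensor without that data, so the op is wrapped in a Lambda layer.
a = Input(shape=(10, 32))
b = Input(shape=(10, 32))
sigmoid_a = Lambda(tf.sigmoid)(a)  # keep 'a' itself as the model input
encoded_a = shared_lstm(sigmoid_a)
encoded_b = shared_lstm(b)
merged = merge([encoded_a, encoded_b], mode='concat')
decoded = RepeatVector(10)(merged)
decoded = LSTM(32, return_sequences=True)(decoded)
model = Model(input=[a, b], output=[decoded])
model.compile(optimizer=Adam(), loss='mse')
model.fit([x1, x2], y)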
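The comments in the gist code say that a full model can be built from tensors alone because each tensor carries topology information. As a rough illustration only, here is a toy pure-Python sketch of that idea. Every name in it (`ToyTensor`, `ToyLayer`, `toy_input`, `collect_layers`) is hypothetical, and it is a simplified model, not how Keras actually stores this data. It shows three points from the gist: the layers of a model can be recovered by walking back from its output, calling one layer twice creates two nodes that share that layer, and applying a plain function to a tensor yields a value with no history, which is why arbitrary ops are wrapped in a `Lambda` layer.

```python
# Toy sketch (hypothetical names, not Keras internals): each tensor records
# which layer call ("node") produced it, so a model can be rebuilt from its outputs.


class ToyTensor:
    def __init__(self, value, layer=None, node_index=None):
        self.value = value
        self.layer = layer            # layer that produced this tensor (None = no history)
        self.node_index = node_index  # which call of that layer produced it


class ToyLayer:
    def __init__(self, name, fn):
        self.name = name
        self.fn = fn
        self.inbound_nodes = []  # one list of input tensors per call

    def __call__(self, inputs):
        inputs = inputs if isinstance(inputs, list) else [inputs]
        self.inbound_nodes.append(inputs)  # each call creates a new node
        value = self.fn(*[t.value for t in inputs])
        return ToyTensor(value, self, len(self.inbound_nodes) - 1)


def toy_input(value, name):
    layer = ToyLayer(name, fn=None)
    layer.inbound_nodes.append([])  # an input layer has one node with no inputs
    return ToyTensor(value, layer, 0)


def collect_layers(output):
    """Walk the recorded history back from `output`; return layers in topological order."""
    ordered, seen = [], set()

    def visit(t):
        if t.layer is None:
            raise ValueError("tensor has no layer history; wrap the op in a layer")
        for parent in t.layer.inbound_nodes[t.node_index]:
            visit(parent)
        if id(t.layer) not in seen:
            seen.add(id(t.layer))
            ordered.append(t.layer)

    visit(output)
    return ordered


a = toy_input(1.0, "input_a")
b = toy_input(2.0, "input_b")
shared = ToyLayer("shared_double", lambda x: 2 * x)
encoded_a, encoded_b = shared(a), shared(b)  # one layer, two nodes
merged = ToyLayer("concat", lambda x, y: [x, y])([encoded_a, encoded_b])

wrapped = ToyLayer("lambda_sum", sum)(merged)  # op wrapped in a layer: traceable
raw = ToyTensor(sum(merged.value))             # same op applied directly: history lost

print([layer.name for layer in collect_layers(wrapped)])
# ['input_a', 'shared_double', 'input_b', 'concat', 'lambda_sum']
```

The shared layer appears once in the recovered layer list but has two inbound nodes, and `collect_layers(raw)` raises `ValueError` because `raw` carries no history, mirroring the gist's NOTE about `tf.sigmoid`.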
@fchollet
Author

Masking: How are the masks of sequences passed around between layers?

The same way as before. Some layers can generate a mask based on their input tensor and the previous mask. The mask is then propagated forward. If a layer that does not support masking receives a non-None mask, it raises an error.

Importantly, the new approach is more general than the previous one, so multi-input layers will be able to handle masking.

How it works in practice:

from keras.layers import Input, Masking, LSTM

a = Input(shape=(10, 32))  # example shape: 10 timesteps, 32 features

# This creates a node in the graph linking a to b.
# The mask generated by Masking is stored inside that node.
b = Masking()(a)

# The LSTM retrieves the node that b came from and reads the mask from there.
c = LSTM(32)(b)
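The masking rules described above (some layers compute a mask from their input and the previous mask, the mask is propagated forward, and a layer without masking support errors on a non-None mask) can be illustrated with a toy pure-Python sketch. All names here are hypothetical and this is not Keras code. For brevity the mask is attached directly to the output tensor, standing in for the node the answer describes. The multi-input layer combines its input masks with a logical AND, which is an assumption chosen for the sketch.

```python
# Toy sketch of mask propagation (hypothetical names, not Keras internals).


class ToyTensor:
    def __init__(self, value, mask=None):
        self.value = value  # list of timesteps
        self.mask = mask    # list of booleans or None; stands in for the mask stored in the node


class ToyLayer:
    supports_masking = False

    def __call__(self, inputs):
        inputs = inputs if isinstance(inputs, list) else [inputs]
        masks = [t.mask for t in inputs]
        if not self.supports_masking and any(m is not None for m in masks):
            raise TypeError(f"{type(self).__name__} does not support masking")
        values = [t.value for t in inputs]
        return ToyTensor(self.call(values, masks), self.compute_mask(values, masks))

    def compute_mask(self, values, masks):
        return masks[0]  # default: pass the incoming mask forward unchanged


class ToyMasking(ToyLayer):
    supports_masking = True

    def __init__(self, mask_value=0.0):
        self.mask_value = mask_value

    def call(self, values, masks):
        return values[0]

    def compute_mask(self, values, masks):
        # a timestep is kept (True) unless it equals mask_value
        return [v != self.mask_value for v in values[0]]


class ToySumRNN(ToyLayer):
    """Stands in for an LSTM: sums only the timesteps the mask keeps."""
    supports_masking = True

    def call(self, values, masks):
        mask = masks[0] or [True] * len(values[0])
        return sum(v for v, keep in zip(values[0], mask) if keep)

    def compute_mask(self, values, masks):
        return None  # the sequence is reduced to one value, so no mask goes forward


class ToyAddSequences(ToyLayer):
    """Multi-input layer: adds two sequences; a step is kept only if every input keeps it."""
    supports_masking = True

    def call(self, values, masks):
        return [x + y for x, y in zip(*values)]

    def compute_mask(self, values, masks):
        if all(m is None for m in masks):
            return None
        filled = [m or [True] * len(v) for m, v in zip(masks, values)]
        return [all(steps) for steps in zip(*filled)]


class ToyDense(ToyLayer):
    supports_masking = False

    def call(self, values, masks):
        return [2 * v for v in values[0]]


a = ToyTensor([1.0, 2.0, 0.0, 0.0])
b = ToyMasking()(a)   # the mask travels forward with b
c = ToySumRNN()(b)    # reads the mask and ignores the padded steps
print(b.mask, c.value)  # [True, True, False, False] 3.0

d = ToyAddSequences()([b, ToyMasking()(ToyTensor([5.0, 0.0, 1.0, 0.0]))])
print(d.mask)  # [True, False, False, False]
```

Calling `ToyDense()(b)` raises `TypeError` because `b` carries a mask, while `ToyDense()(a)` works because `a` has none.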

The Graph model can be used as a queryable data structure when copying weights from one model to another model with a different config. Will this be possible in the new API?

Yes. This is an important feature. You will still be able to iterate over the layers in a graph and query a layer by name.

from keras.layers import Input, Dense
from keras.models import Model

a = Input(shape=(16,))  # example input shape
b = Dense(32, name='my_dense')(a)
c = Dense(32, name='output')(b)

model = Model(a, c)

# List of all layers, in order of horizontal graph traversal.
# For a sequential model this is just the ordered list of layers, starting with the input layer.
model.layers

first_dense_instance = model.get_layer(name='my_dense')
# Index 0 is the input layer, so the first Dense layer is at index 1.
first_dense_instance = model.get_layer(index=1)
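To make the layer-query answer concrete, and to connect it to the weight-copying use case in the question, here is a toy pure-Python sketch. `ToyLayer`, `ToyModel` and `copy_weights_by_name` are hypothetical names and this is not the Keras implementation: the model holds an ordered layer list (input layer first), `get_layer` looks a layer up by name or index, and weights are copied only for layer names the two models share. In real Keras, the per-layer copy would typically go through each layer's `get_weights()` and `set_weights()` methods.

```python
# Toy sketch (hypothetical names, not Keras): querying layers by name or index,
# and copying weights between two models whose configurations differ.


class ToyLayer:
    def __init__(self, name, weights=None):
        self.name = name
        self.weights = list(weights or [])


class ToyModel:
    def __init__(self, layers):
        self.layers = list(layers)  # already in traversal order, input layer first

    def get_layer(self, name=None, index=None):
        if index is not None:
            return self.layers[index]
        for layer in self.layers:
            if layer.name == name:
                return layer
        raise ValueError(f"no layer named {name!r}")


def copy_weights_by_name(source, target):
    """Copy weights for every weighted layer name both models share; return the copied names."""
    copied = []
    for layer in target.layers:
        try:
            src = source.get_layer(name=layer.name)
        except ValueError:
            continue  # layer exists only in the target model
        if src.weights:  # skip weightless layers such as the input layer
            layer.weights = list(src.weights)
            copied.append(layer.name)
    return copied


source = ToyModel([ToyLayer("input"), ToyLayer("my_dense", [0.5, -1.0]), ToyLayer("output", [2.0])])
target = ToyModel([ToyLayer("input"), ToyLayer("my_dense"), ToyLayer("extra"), ToyLayer("head")])

print(source.get_layer(index=1).name)             # my_dense
print(copy_weights_by_name(source, target))       # ['my_dense']
print(target.get_layer(name="my_dense").weights)  # [0.5, -1.0]
```

Matching by name rather than by position is what lets the two models differ in configuration: layers present in only one model are simply skipped.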

@GregorySenay

Very interesting new API: less verbose, more readable, and it avoids a lot of input=X, name=X. I like it!
What about accessing an intermediate layer, as in a Siamese network?
