
@bzamecnik
Last active August 29, 2017 17:55
Demultiplexing inputs within Keras layers
"""
In this example we show how to select and separately process multiple input
features within Keras layers.
Let's say we have a model with two categorical features and we want to embed
or one-hot encode each one separately. Normally in the Functional API we
would make two Input layers, one for each feature, then connect an Embedding
layer to each, merge them and then add some more Dense/LSTM/... layers. In
this case we need to pass model.predict() a list of input arrays instead of
just one. It becomes a bit cumbersome if you need to index and slice all of
these arrays at once. Is it possible to provide just a single input and
demultiplex it in the model? The answer is yes!
We can just use multiple Lambda layers with custom selection.
Let's have two features in the last dimension:
input_one = Lambda(lambda i: i[..., 0])(input_all)
input_two = Lambda(lambda i: i[..., 1])(input_all)
Simple as that!
It works only if all the features have the same dtype. If you had both
integer indexes and floating point values, you'd probably need multiple
inputs, or to cast the types within a Lambda layer.
See also the tutorial on how to do one-hot encoding in a Keras layer:
https://gist.github.com/bzamecnik/a33052ec46ee7efeb217856d98a4fb5f
"""
import numpy as np
from keras.models import Model
from keras.layers import Embedding, Input, Lambda, concatenate

frame_size = 10
feature_count = 2
class_counts = [200, 50]
embedding_sizes = [50, 10]

input_shape = (frame_size, feature_count)
# a single integer input holding all categorical features in the last dimension
features_input = Input(shape=input_shape, dtype='int32')

def select_and_encode_feature(index, nb_classes, embedding_size):
    # demultiplex: select one feature from the last dimension -> (batch, frame_size)
    input_int = Lambda(lambda i: i[..., index])(features_input)
    # embed that feature separately -> (batch, frame_size, embedding_size)
    input_encoded = Embedding(
        input_dim=nb_classes,
        output_dim=embedding_size,
        input_length=frame_size)(input_int)
    return input_encoded

features_encoded = [select_and_encode_feature(i, class_count, embedding_size)
                    for i, (class_count, embedding_size)
                    in enumerate(zip(class_counts, embedding_sizes))]
# concatenate along the last axis (Keras 2 replacement for the Keras 1
# call merge(features_encoded, mode='concat'))
embeddings_merged = concatenate(features_encoded)
model = Model(features_input, embeddings_merged)

# one sample of shape (1, frame_size, feature_count), each feature within its class range
x = np.dstack([np.random.randint(low=0, high=c, size=frame_size) for c in class_counts])
assert model.predict(x).shape == (1, frame_size, sum(embedding_sizes))
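A NumPy sketch of what the model above computes may help to check the shapes without running Keras. It reuses the same `frame_size`, `class_counts` and `embedding_sizes`; the random embedding tables and the `demux_and_embed` helper are hypothetical stand-ins for the `Embedding` layers' weights, not part of the Keras API.

```python
import numpy as np

frame_size = 10
class_counts = [200, 50]
embedding_sizes = [50, 10]

rng = np.random.default_rng(0)

# one packed integer array of shape (batch=1, frame_size, feature_count)
x = np.dstack([rng.integers(low=0, high=c, size=frame_size) for c in class_counts])

# hypothetical random tables standing in for each Embedding layer's weights
tables = [rng.normal(size=(c, d)).astype(np.float32)
          for c, d in zip(class_counts, embedding_sizes)]


def demux_and_embed(packed, tables):
    """Slice each feature from the last axis, look it up in its table, concatenate."""
    encoded = []
    for index, table in enumerate(tables):
        feature = packed[..., index]    # same slice as the Lambda layer: (batch, frame_size)
        encoded.append(table[feature])  # embedding lookup: (batch, frame_size, embedding_size)
    return np.concatenate(encoded, axis=-1)


out = demux_and_embed(x, tables)
print(x.shape, x[..., 0].shape, out.shape)  # (1, 10, 2) (1, 10) (1, 10, 60)
```

Indexing a table with an integer array is the per-feature lookup an Embedding layer performs, so the concatenated result has the same `(1, frame_size, sum(embedding_sizes))` shape that the assert in the script checks.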
@bzamecnik
Author

This will probably fail when loading the saved model.

The index argument of the lambda passed to Lambda must be stored explicitly, e.g. using a default argument.

Example:

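def extract(index):
    # store the input index in the function itself using a default argument,
    # instead of capturing it from the enclosing scope in a closure
    def func(inputs, index=index):
        # select one feature from the last dimension, as in the script above
        return inputs[..., index]
    return func

separate_inputs = [Lambda(extract(i))(features_input) for i in range(feature_count)]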
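Why the default argument helps can be sketched in plain Python. This assumes a serializer that stores only the marshalled code object and the function's `__defaults__`, not its closure cells; that is a simplifying assumption, and the `dump`/`load` helpers below are hypothetical, not Keras code.

```python
import marshal
import types


def extract_closure(index):
    # index is captured in a closure cell (a free variable of func)
    def func(inputs):
        return inputs[index]
    return func


def extract_default(index):
    # index is stored in func.__defaults__, as in the comment above
    def func(inputs, index=index):
        return inputs[index]
    return func


def dump(func):
    # hypothetical serializer: keep only the bytecode and the default values
    return marshal.dumps(func.__code__), func.__defaults__


def load(raw_code, defaults):
    # rebuild the function without any closure cells
    code = marshal.loads(raw_code)
    return types.FunctionType(code, globals(), None, defaults)


def survives_roundtrip(factory, index, data):
    try:
        restored = load(*dump(factory(index)))
    except (TypeError, ValueError):
        # the code object needs a closure cell for index, but none was saved
        return False
    return restored(data) == data[index]


print(survives_roundtrip(extract_default, 1, [10, 20, 30]))  # True
print(survives_roundtrip(extract_closure, 1, [10, 20, 30]))  # False
```

The closure version cannot be rebuilt because its code object expects a cell for `index` that the dump never recorded, while the default-argument version carries `index` in `__defaults__` and comes back intact.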
