Last active
August 29, 2017 17:55
Demultiplexing inputs within Keras layers
"""
In this example we show how to select and separately process multiple input
features within Keras layers.

Let's say we have a model with two categorical features and we want to embed
or one-hot encode each one separately. Normally in the Functional API we
would make two Input layers, one for each feature, then connect an Embedding
to each, merge them and then add some more Dense/LSTM/... layers. In this
case we need to provide model.predict() with a list of input arrays
instead of just one. It becomes a bit cumbersome if you need to index and
slice all of these arrays at once. Is it possible to provide just a single
input and demultiplex it in the model? The answer is yes!

We can just use multiple Lambda layers with custom selection.

Let's have two features in the last dimension:

    input_one = Lambda(lambda i: i[..., 0])(input_all)
    input_two = Lambda(lambda i: i[..., 1])(input_all)

Simple as that!

It works only if the inputs have the same type. If you had some integer
indexes and floating point values, you'd probably need multiple inputs or
to cast the types within some Lambda layer.

See also the tutorial on how to do one-hot encoding in a Keras layer:
https://gist.github.com/bzamecnik/a33052ec46ee7efeb217856d98a4fb5f
"""
import numpy as np

from keras import backend as K
from keras.models import Model
from keras.layers import Embedding, Input, Lambda, merge

frame_size = 10
feature_count = 2
class_counts = [200, 50]
embedding_sizes = [50, 10]

input_shape = (frame_size, feature_count)
features_input = Input(shape=input_shape, dtype='int32')


def select_and_encode_feature(index, nb_classes, embedding_size):
    # Pick one feature from the last axis, then embed it.
    input_int = Lambda(lambda i: i[..., index])(features_input)
    input_encoded = Embedding(
        input_dim=nb_classes,
        output_dim=embedding_size,
        input_length=frame_size)(input_int)
    return input_encoded


features_encoded = [select_and_encode_feature(i, class_count, embedding_size)
                    for i, (class_count, embedding_size)
                    in enumerate(zip(class_counts, embedding_sizes))]

# Keras 1 API; Keras 2 provides keras.layers.concatenate for the same purpose.
embeddings_merged = merge(features_encoded, mode='concat')

model = Model(features_input, embeddings_merged)

# One sample of shape (1, frame_size, feature_count).
x = np.dstack([np.random.randint(low=0, high=c, size=frame_size) for c in class_counts])

assert model.predict(x).shape == (1, frame_size, sum(embedding_sizes))
Probably it will fail while loading the saved model. Argument `index` in the lambda passed to `Lambda` must be stored, e.g. using default arguments. Example:
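The commenter's original example is not preserved here. As a rough sketch of the idea only, the plain-Python snippet below contrasts a lambda that reads `index` from its enclosing scope with one that binds it as a default argument. The helper names `make_selector_closure` and `make_selector_default` are hypothetical, and a plain list stands in for a tensor. Whether a given Keras version actually restores closure values when loading a saved model is not verified here.

```python
def make_selector_closure(index):
    # `index` lives in a closure cell of the returned lambda.
    return lambda row: row[index]


def make_selector_default(index):
    # `index` is bound as a default argument, so its value is stored
    # on the function object itself (in `__defaults__`).
    return lambda row, index=index: row[index]


closure_fn = make_selector_closure(1)
default_fn = make_selector_default(1)

row = [10, 20]
# Both variants select the same element at call time.
assert closure_fn(row) == default_fn(row) == 20

# The difference is where the value of `index` is kept.
print(closure_fn.__defaults__)                  # None
print(default_fn.__defaults__)                  # (1,)
print(closure_fn.__closure__[0].cell_contents)  # 1

# Assumed application to the gist (not run here):
#   Lambda(lambda i, index=index: i[..., index])(features_input)
```

With the default-argument form, the selected index travels with the function's own attributes instead of depending on the scope in which the lambda was created.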