Skip to content

Instantly share code, notes, and snippets.

Created August 17, 2017 23:53
Show Gist options
  • Save namakemono/f4f273dbc63fc2174940415a9f689a6f to your computer and use it in GitHub Desktop.
Save namakemono/f4f273dbc63fc2174940415a9f689a6f to your computer and use it in GitHub Desktop.
Decomposable Attention with Keras.
from keras.layers import *
from keras.activations import softmax
from keras.models import Model
# Reference: [1] Parikh, Ankur P., et al. "A decomposable attention model for natural language inference." arXiv preprint arXiv:1606.01933 (2016).
def StaticEmbedding(embedding_matrix):
    """Return a non-trainable Embedding layer initialised from ``embedding_matrix``.

    The matrix is unpacked as ``(vocab_size, embed_width)`` from its ``shape``
    (so it must be 2-D), used as the layer's initial weights, and frozen.
    """
    vocab_size, embed_width = embedding_matrix.shape
    return Embedding(
        vocab_size,
        embed_width,
        weights=[embedding_matrix],
        trainable=False,
    )
def unchanged_shape(input_shape):
    """Identity shape function for shape-preserving ``Lambda`` layers.

    Hands back the very same ``input_shape`` object it was given.
    """
    output_shape = input_shape
    return output_shape
def time_distributed(x, layers):
    """Feed ``x`` through every layer in ``layers``, each wrapped in TimeDistributed.

    Layers are applied in iteration order, the output of one becoming the
    input of the next. An empty ``layers`` returns ``x`` itself.
    """
    output = x
    for layer in layers:
        wrapped = TimeDistributed(layer)
        output = wrapped(output)
    return output
def align(input_1, input_2):
    """Soft-align two encoded sequences against each other (the "attend" step).

    Scores every pair of positions with a dot product, normalises the scores
    with a softmax in each direction, and uses the weights to summarise one
    sequence at every position of the other.

    Args:
        input_1: encoded first sequence; assumed (batch, len_1, dim) — TODO confirm.
        input_2: encoded second sequence; assumed (batch, len_2, dim).

    Returns:
        ``(in1_aligned, in2_aligned)`` where, under the shape assumption above,
        ``in1_aligned`` is a weighted sum of ``input_1`` for each position of
        ``input_2`` and ``in2_aligned`` is a weighted sum of ``input_2`` for
        each position of ``input_1``.
    """
    # Pairwise scores: attention[b, i, j] = <input_1[b, i], input_2[b, j]>.
    attention = Dot(axes=-1)([input_1, input_2])
    # Normalise over positions of input_1 (axis 1).
    w_att_1 = Lambda(lambda x: softmax(x, axis=1),
                     output_shape=unchanged_shape)(attention)
    # Normalise over positions of input_2 (axis 2), then swap the two length
    # axes so input_2's axis is axis 1, ready to contract against input_2.
    w_att_2 = Permute((2, 1))(Lambda(lambda x: softmax(x, axis=2),
                                     output_shape=unchanged_shape)(attention))
    # Contract the shared length axis (axis 1) of the weights and the inputs.
    in1_aligned = Dot(axes=1)([w_att_1, input_1])
    in2_aligned = Dot(axes=1)([w_att_2, input_2])
    return in1_aligned, in2_aligned
def aggregate(x1, x2, num_class, dense_dim=300, dropout_rate=0.2, activation="relu"):
    """Pool both compared sequences and classify them (the "aggregate" step).

    Each sequence is reduced with global average- and max-pooling. The four
    pooled vectors are concatenated and passed through two
    BatchNormalization -> Dense -> Dropout stages, then a sigmoid output layer.

    Args:
        x1, x2: per-timestep comparison features; assumed (batch, len, features).
        num_class: number of output units.
        dense_dim: width of the two hidden Dense layers.
        dropout_rate: dropout applied after each hidden Dense layer.
        activation: activation of the hidden Dense layers.

    Returns:
        Output tensor of ``num_class`` sigmoid units. With ``num_class > 1``
        these are independent per-unit probabilities, not a softmax.
    """
    def _pool(x):
        # Keras merge layers require a real list of inputs; a lazy ``map``
        # object (Python 3) is rejected, so build the list explicitly.
        return concatenate([GlobalAvgPool1D()(x), GlobalMaxPool1D()(x)])

    feat1 = _pool(x1)
    feat2 = _pool(x2)
    x = Concatenate()([feat1, feat2])
    # Two identical hidden stages (same layer order as before).
    for _ in range(2):
        x = BatchNormalization()(x)
        x = Dense(dense_dim, activation=activation)(x)
        x = Dropout(dropout_rate)(x)
    scores = Dense(num_class, activation='sigmoid')(x)
    return scores
def build_model(embedding_matrix, num_class=1,
                projection_dim=300, projection_hidden=0, projection_dropout=0.2,
                compare_dim=500, compare_dropout=0.2,
                dense_dim=300, dropout_rate=0.2,
                lr=1e-3, activation='relu', maxlen=30):
    """Build the Decomposable Attention model (Parikh et al., 2016) for two inputs.

    Pipeline: frozen embedding -> per-timestep projection -> soft alignment
    (``align``) -> per-timestep comparison -> pooling + classifier
    (``aggregate``).

    Args:
        embedding_matrix: 2-D array used to initialise the frozen embedding.
        num_class: number of sigmoid output units.
        projection_dim: width of the final (linear) projection layer.
        projection_hidden: if > 0, width of an extra hidden projection layer
            placed before the linear one.
        projection_dropout: dropout after each projection Dense layer.
        compare_dim: width of the two comparison Dense layers.
        compare_dropout: dropout after each comparison Dense layer.
        dense_dim, dropout_rate, activation: forwarded to ``aggregate``
            (``activation`` is also used by the projection/comparison layers).
        lr: accepted for API compatibility but unused here; the returned
            model is not compiled.
        maxlen: length of each input sequence.

    Returns:
        An uncompiled ``Model`` with inputs ``[q1, q2]`` and the classifier
        scores as output.
    """
    q1 = Input(name='q1', shape=(maxlen,))
    q2 = Input(name='q2', shape=(maxlen,))

    # Embedding: a single frozen layer applied to both inputs.
    encode = StaticEmbedding(embedding_matrix)
    q1_embed = encode(q1)
    q2_embed = encode(q2)

    # Projection: optional hidden layer, then a linear projection.
    projection_layers = []
    if projection_hidden > 0:
        projection_layers.extend([
            Dense(projection_hidden, activation=activation),
            Dropout(projection_dropout),
        ])
    projection_layers.extend([
        Dense(projection_dim, activation=None),
        Dropout(projection_dropout),
    ])
    # The same layer instances are reused for both inputs, so their weights are shared.
    q1_encoded = time_distributed(q1_embed, projection_layers)
    q2_encoded = time_distributed(q2_embed, projection_layers)

    # Attention
    q1_aligned, q2_aligned = align(q1_encoded, q2_encoded)

    # Compare each position with the aligned summary of the other sequence.
    q1_combined = concatenate([q1_encoded, q2_aligned])
    q2_combined = concatenate([q2_encoded, q1_aligned])
    compare_layers = [
        Dense(compare_dim, activation=activation),
        Dropout(compare_dropout),
        Dense(compare_dim, activation=activation),
        Dropout(compare_dropout),
    ]
    q1_compare = time_distributed(q1_combined, compare_layers)
    q2_compare = time_distributed(q2_combined, compare_layers)

    # Aggregate: forward the head hyper-parameters instead of silently
    # falling back to aggregate()'s defaults.
    scores = aggregate(q1_compare, q2_compare, num_class,
                       dense_dim=dense_dim, dropout_rate=dropout_rate,
                       activation=activation)
    model = Model(inputs=[q1, q2], outputs=scores)
    return model
if __name__ == "__main__":
    import numpy as np
    # Smoke test: build the model over a dummy all-zero 30x20 embedding table.
    model = build_model(embedding_matrix=np.zeros((30, 20)), projection_hidden=200)
    # summary() prints the table itself and returns None, so it must not be
    # wrapped in print (which also used Python-2-only statement syntax).
    model.summary()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment