Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
Implementation of ESIM (Enhanced LSTM for Natural Language Inference)
"""
Implementation of ESIM(Enhanced LSTM for Natural Language Inference)
https://arxiv.org/abs/1609.06038
"""
import numpy as np
from keras.layers import *
from keras.activations import softmax
from keras.models import Model
def StaticEmbedding(embedding_matrix):
    """Create a non-trainable Embedding layer initialised from a matrix.

    Args:
        embedding_matrix: Object exposing a two-element ``.shape``; the
            first element is used as the layer's input dimension and the
            second as its output dimension. The matrix itself is passed
            as the layer's initial weights.

    Returns:
        An ``Embedding`` layer with ``trainable=False``.
    """
    vocab_size, vector_size = embedding_matrix.shape
    frozen_layer = Embedding(
        input_dim=vocab_size,
        output_dim=vector_size,
        weights=[embedding_matrix],
        trainable=False,
    )
    return frozen_layer
def subtract(input_1, input_2):
    """Return the element-wise difference ``input_1 - input_2``.

    The difference is built as ``input_1 + (-input_2)``: a ``Lambda``
    layer negates ``input_2`` and the result is merged with ``add``.
    """
    negate = Lambda(lambda tensor: -tensor)
    return add([input_1, negate(input_2)])
def aggregate(input_1, input_2, num_dense=300, dropout_rate=0.5):
    """Pool two sequences and pass the joined features through an MLP.

    Each input is summarised by concatenating its global average pooling
    with its global max pooling. The two summaries are concatenated,
    batch-normalised, and then fed through two identical stages of
    ``Dense(relu) -> BatchNormalization -> Dropout``.

    Args:
        input_1: First sequence tensor.
        input_2: Second sequence tensor.
        num_dense: Number of units in each of the two Dense layers.
        dropout_rate: Rate used by each of the two Dropout layers.

    Returns:
        The tensor produced by the second Dropout layer.
    """
    summaries = []
    for sequence in (input_1, input_2):
        mean_pooled = GlobalAvgPool1D()(sequence)
        max_pooled = GlobalMaxPool1D()(sequence)
        summaries.append(concatenate([mean_pooled, max_pooled]))

    merged = concatenate(summaries)
    hidden = BatchNormalization()(merged)

    # Two identical fully-connected stages.
    for _ in range(2):
        hidden = Dense(num_dense, activation='relu')(hidden)
        hidden = BatchNormalization()(hidden)
        hidden = Dropout(dropout_rate)(hidden)
    return hidden
def align(input_1, input_2):
    """Soft-align two sequences against each other with dot-product attention.

    A score tensor is computed by contracting the last axis of both
    inputs. The scores are softmax-normalised once along axis 1 and once
    along axis 2 (the latter is then transposed with ``Permute((2, 1))``),
    and each set of weights is used to take a weighted sum of the
    corresponding input along its axis 1.

    Assumes both inputs are 3-D ``(batch, steps, features)`` tensors with
    matching feature size -- TODO confirm against callers.

    Returns:
        A tuple ``(in1_aligned, in2_aligned)``: the weighted sums built
        from ``input_1`` and from ``input_2`` respectively.
    """
    scores = Dot(axes=-1)([input_1, input_2])

    softmax_axis_1 = Lambda(lambda s: softmax(s, axis=1))
    softmax_axis_2 = Lambda(lambda s: softmax(s, axis=2))

    weights_for_1 = softmax_axis_1(scores)
    # Transpose so that axis 1 of the weights lines up with axis 1 of input_2.
    weights_for_2 = Permute((2, 1))(softmax_axis_2(scores))

    in1_aligned = Dot(axes=1)([weights_for_1, input_1])
    in2_aligned = Dot(axes=1)([weights_for_2, input_2])
    return in1_aligned, in2_aligned
def build_model(embedding_matrix, num_class=1, max_length=30, lstm_dim=300):
    """Assemble the two-input ESIM model.

    Pipeline: shared frozen embedding (followed by a separate
    BatchNormalization per input) -> shared bidirectional LSTM encoder ->
    soft alignment -> comparison features (own encoding, aligned other,
    their difference, their product) -> shared bidirectional LSTM ->
    pooled aggregation MLP -> sigmoid Dense output.

    Args:
        embedding_matrix: Pre-built embedding weights, forwarded to
            ``StaticEmbedding``.
        num_class: Number of units in the final sigmoid Dense layer.
        max_length: Length of each of the two input sequences.
        lstm_dim: Units per direction in both bidirectional LSTMs.

    Returns:
        A ``Model`` taking ``[q1, q2]`` and producing the sigmoid output.

    Note:
        No masking layer is applied in this function, and ``aggregate``
        is called with its default ``num_dense`` / ``dropout_rate``.
    """
    q1 = Input(shape=(max_length,))
    q2 = Input(shape=(max_length,))

    # Embedding: one shared lookup table, one BatchNormalization per input.
    shared_embedding = StaticEmbedding(embedding_matrix)
    q1_embed = BatchNormalization(axis=2)(shared_embedding(q1))
    q2_embed = BatchNormalization(axis=2)(shared_embedding(q2))

    # Encoding: the same BiLSTM is applied to both inputs.
    encoder = Bidirectional(LSTM(lstm_dim, return_sequences=True))
    q1_encoded = encoder(q1_embed)
    q2_encoded = encoder(q2_embed)

    # Alignment
    q1_aligned, q2_aligned = align(q1_encoded, q2_encoded)

    # Compare: enrich each encoding with its aligned counterpart.
    def _comparison_features(own, other):
        """Concatenate own, other, own - other and own * other."""
        return concatenate([
            own,
            other,
            subtract(own, other),
            multiply([own, other]),
        ])

    q1_combined = _comparison_features(q1_encoded, q2_aligned)
    q2_combined = _comparison_features(q2_encoded, q1_aligned)

    comparer = Bidirectional(LSTM(lstm_dim, return_sequences=True))
    q1_compare = comparer(q1_combined)
    q2_compare = comparer(q2_combined)

    # Aggregate
    features = aggregate(q1_compare, q2_compare)
    prediction = Dense(num_class, activation='sigmoid')(features)
    return Model(inputs=[q1, q2], outputs=prediction)
@nguyenhuyx

This comment has been minimized.

Copy link

commented Oct 4, 2018

What about masking?
I looked at your code, but it does not apply any masking.

@Elfsong

This comment has been minimized.

Copy link

commented May 18, 2019

May I ask what embedding_matrix is, and how it is generated?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.