Skip to content

Instantly share code, notes, and snippets.

@Elfsong Elfsong/ forked from namakemono/
Created May 18, 2019

What would you like to do?
Implementation of ESIM (Enhanced LSTM for Natural Language Inference)
Implementation of ESIM (Enhanced LSTM for Natural Language Inference)
import numpy as np
from keras.layers import *
from keras.activations import softmax
from keras.models import Model
def StaticEmbedding(embedding_matrix):
    """Build a frozen ``Embedding`` layer initialised from *embedding_matrix*.

    Args:
        embedding_matrix: Array whose ``shape`` unpacks into exactly two
            values, used as the layer's ``input_dim`` and ``output_dim``.

    Returns:
        An ``Embedding`` layer seeded with the matrix and created with
        ``trainable=False`` so the vectors are not updated during training.
    """
    vocab_size, embed_dim = embedding_matrix.shape
    return Embedding(
        input_dim=vocab_size,
        output_dim=embed_dim,
        weights=[embedding_matrix],
        trainable=False,
    )
def subtract(input_1, input_2):
    """Return the element-wise difference ``input_1 - input_2`` as a tensor.

    Args:
        input_1: Keras tensor to subtract from.
        input_2: Keras tensor to subtract; must be shape-compatible with
            ``input_1``.

    Returns:
        The Keras tensor ``input_1 - input_2``.
    """
    # Use the built-in merge layer instead of negating through a ``Lambda``
    # and adding: one layer instead of two, and no anonymous Python function
    # embedded in the model graph (which complicates saving/loading).
    # The class form is required here because this function shadows the
    # functional ``subtract`` helper pulled in by the star import.
    return Subtract()([input_1, input_2])
def aggregate(input_1, input_2, num_dense=300, dropout_rate=0.5):
    """Pool two sequences and pass the joint vector through a dense stack.

    Each input is summarised by concatenating its global average pool and
    global max pool. The two summaries are concatenated, batch-normalised,
    and then fed through two identical stages of
    ``Dense(relu) -> BatchNormalization -> Dropout``.

    Args:
        input_1: First sequence tensor (must be valid input for 1-D global
            pooling).
        input_2: Second sequence tensor, pooled the same way.
        num_dense: Number of units in each ``Dense`` layer.
        dropout_rate: Rate passed to each ``Dropout`` layer.

    Returns:
        The output tensor of the final ``Dropout`` layer.
    """

    def _pool(sequence):
        # Average pooling first, then max pooling, joined on the last axis.
        return concatenate([
            GlobalAvgPool1D()(sequence),
            GlobalMaxPool1D()(sequence),
        ])

    summary_1 = _pool(input_1)
    summary_2 = _pool(input_2)
    hidden = BatchNormalization()(concatenate([summary_1, summary_2]))

    # Two identical fully connected stages.
    for _ in range(2):
        hidden = Dense(num_dense, activation='relu')(hidden)
        hidden = BatchNormalization()(hidden)
        hidden = Dropout(dropout_rate)(hidden)
    return hidden
def align(input_1, input_2):
    """Soft-align two sequences against each other via dot-product attention.

    A score matrix is formed by contracting the last axis of both inputs.
    It is softmax-normalised once along axis 1 and once along axis 2, and
    each normalised matrix is used to take a weighted sum over the steps of
    one input.

    Assumes both inputs are rank-3 ``(batch, steps, features)`` tensors with
    the same feature size -- TODO confirm against callers.

    Args:
        input_1: First sequence tensor.
        input_2: Second sequence tensor.

    Returns:
        A tuple ``(in1_aligned, in2_aligned)``:
        ``in1_aligned`` is a weighted sum over the steps of ``input_1``
        (one row per step of ``input_2``), and ``in2_aligned`` is a weighted
        sum over the steps of ``input_2`` (one row per step of ``input_1``).
    """
    # scores[b, i, j] = <input_1[b, i], input_2[b, j]>
    scores = Dot(axes=-1)([input_1, input_2])

    # Normalise over the steps of input_1 (axis 1 of the score matrix).
    weights_over_1 = Lambda(lambda t: softmax(t, axis=1))(scores)

    # Normalise over the steps of input_2 (axis 2), then swap the two step
    # axes so that the normalised axis sits at position 1 for the Dot below.
    weights_over_2 = Lambda(lambda t: softmax(t, axis=2))(scores)
    weights_over_2 = Permute((2, 1))(weights_over_2)

    # Contract axis 1 of the weights with the step axis of each input.
    in1_aligned = Dot(axes=1)([weights_over_1, input_1])
    in2_aligned = Dot(axes=1)([weights_over_2, input_2])
    return in1_aligned, in2_aligned
def build_model(embedding_matrix, num_class=1, max_length=30, lstm_dim=300,
                num_dense=300, dropout_rate=0.5):
    """Assemble the two-input ESIM model.

    Pipeline: shared frozen embedding + batch norm -> shared BiLSTM encoder
    -> soft alignment between the two encodings -> enhanced features
    ``[a, b, a - b, a * b]`` -> shared BiLSTM composition -> pooled dense
    head -> sigmoid output.

    Args:
        embedding_matrix: Pre-trained embedding weights handed to
            ``StaticEmbedding``.
        num_class: Number of units in the output ``Dense`` layer.
        max_length: Length of each input sequence.
        lstm_dim: Units per direction in both bidirectional LSTMs.
        num_dense: Units of the dense layers in the aggregation head.
            Defaults to the value the head previously always used.
        dropout_rate: Dropout rate in the aggregation head. Defaults to the
            value the head previously always used.

    Returns:
        An uncompiled ``Model`` taking ``[q1, q2]`` and producing the output
        of the final sigmoid layer.

    NOTE(review): the output activation is ``sigmoid`` regardless of
    ``num_class``; if ``num_class > 1`` is meant to model mutually exclusive
    classes, ``softmax`` may be what is wanted -- confirm with callers.
    """

    def _enhance(own, other_aligned):
        # Local-inference features: the encoding, its aligned counterpart,
        # their difference and their element-wise product.
        return concatenate([
            own,
            other_aligned,
            subtract(own, other_aligned),
            multiply([own, other_aligned]),
        ])

    q1 = Input(shape=(max_length,))
    q2 = Input(shape=(max_length,))

    # Embedding: a single frozen lookup table shared by both inputs, each
    # followed by its own batch normalisation over the feature axis.
    embedding = StaticEmbedding(embedding_matrix)
    q1_embed = BatchNormalization(axis=2)(embedding(q1))
    q2_embed = BatchNormalization(axis=2)(embedding(q2))

    # Encoding: one BiLSTM instance applied to both sides (shared weights).
    encode = Bidirectional(LSTM(lstm_dim, return_sequences=True))
    q1_encoded = encode(q1_embed)
    q2_encoded = encode(q2_embed)

    # Alignment
    q1_aligned, q2_aligned = align(q1_encoded, q2_encoded)

    # Compare: each side is paired with the *other* side's aligned output.
    q1_combined = _enhance(q1_encoded, q2_aligned)
    q2_combined = _enhance(q2_encoded, q1_aligned)
    compare = Bidirectional(LSTM(lstm_dim, return_sequences=True))
    q1_compare = compare(q1_combined)
    q2_compare = compare(q2_combined)

    # Aggregate: head size and dropout are now caller-configurable instead
    # of being pinned to aggregate()'s defaults.
    x = aggregate(q1_compare, q2_compare,
                  num_dense=num_dense, dropout_rate=dropout_rate)
    x = Dense(num_class, activation='sigmoid')(x)
    return Model(inputs=[q1, q2], outputs=x)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.