"""
Implementation of ESIM (Enhanced LSTM for Natural Language Inference)
https://arxiv.org/abs/1609.06038
"""
import numpy as np
from keras.layers import *
from keras.activations import softmax
from keras.models import Model
def StaticEmbedding(embedding_matrix):
    """Frozen embedding layer initialized with pretrained word vectors."""
    in_dim, out_dim = embedding_matrix.shape
    return Embedding(in_dim, out_dim, weights=[embedding_matrix], trainable=False)
def subtract(input_1, input_2):
    """Element-wise difference, expressed as negation + add so it works on Keras tensors."""
    minus_input_2 = Lambda(lambda x: -x)(input_2)
    return add([input_1, minus_input_2])
def aggregate(input_1, input_2, num_dense=300, dropout_rate=0.5):
    """Pool each compared sequence (avg + max), concatenate, and pass through an MLP."""
    feat1 = concatenate([GlobalAvgPool1D()(input_1), GlobalMaxPool1D()(input_1)])
    feat2 = concatenate([GlobalAvgPool1D()(input_2), GlobalMaxPool1D()(input_2)])
    x = concatenate([feat1, feat2])
    x = BatchNormalization()(x)
    x = Dense(num_dense, activation='relu')(x)
    x = BatchNormalization()(x)
    x = Dropout(dropout_rate)(x)
    x = Dense(num_dense, activation='relu')(x)
    x = BatchNormalization()(x)
    x = Dropout(dropout_rate)(x)
    return x
def align(input_1, input_2):
    """Soft alignment: dot-product attention between the two encoded sequences."""
    attention = Dot(axes=-1)([input_1, input_2])
    # Normalize the attention matrix along each axis to get weights for the two directions.
    w_att_1 = Lambda(lambda x: softmax(x, axis=1))(attention)
    w_att_2 = Permute((2, 1))(Lambda(lambda x: softmax(x, axis=2))(attention))
    in1_aligned = Dot(axes=1)([w_att_1, input_1])
    in2_aligned = Dot(axes=1)([w_att_2, input_2])
    return in1_aligned, in2_aligned
def build_model(embedding_matrix, num_class=1, max_length=30, lstm_dim=300):
    q1 = Input(shape=(max_length,))
    q2 = Input(shape=(max_length,))

    # Embedding
    embedding = StaticEmbedding(embedding_matrix)
    q1_embed = BatchNormalization(axis=2)(embedding(q1))
    q2_embed = BatchNormalization(axis=2)(embedding(q2))

    # Encoding
    encode = Bidirectional(LSTM(lstm_dim, return_sequences=True))
    q1_encoded = encode(q1_embed)
    q2_encoded = encode(q2_embed)

    # Alignment
    q1_aligned, q2_aligned = align(q1_encoded, q2_encoded)

    # Compare
    q1_combined = concatenate([q1_encoded, q2_aligned, subtract(q1_encoded, q2_aligned), multiply([q1_encoded, q2_aligned])])
    q2_combined = concatenate([q2_encoded, q1_aligned, subtract(q2_encoded, q1_aligned), multiply([q2_encoded, q1_aligned])])
    compare = Bidirectional(LSTM(lstm_dim, return_sequences=True))
    q1_compare = compare(q1_combined)
    q2_compare = compare(q2_combined)

    # Aggregate
    x = aggregate(q1_compare, q2_compare)
    x = Dense(num_class, activation='sigmoid')(x)
    return Model(inputs=[q1, q2], outputs=x)
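

# Example usage -- a minimal sketch, not part of the original gist: the embedding
# matrix, vocabulary size, and the padded question pairs below are random
# placeholders standing in for pretrained word vectors and real tokenized data.
if __name__ == '__main__':
    vocab_size, embed_dim, max_length = 20000, 300, 30
    embedding_matrix = np.random.uniform(-0.1, 0.1, (vocab_size, embed_dim))

    model = build_model(embedding_matrix, num_class=1, max_length=max_length)
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    model.summary()

    # Toy batch of padded token-id sequences with binary labels, just to show the expected shapes.
    q1 = np.random.randint(1, vocab_size, size=(8, max_length))
    q2 = np.random.randint(1, vocab_size, size=(8, max_length))
    y = np.random.randint(0, 2, size=(8, 1))
    model.fit([q1, q2], y, batch_size=8, epochs=1)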