Match-LSTM to match passage encodes with question encodes
# Python 3.6 + TF 1.6
# Han Xiao (artex.xh@gmail.com)
import tensorflow as tf
import tensorflow.contrib as tc
#### Usage: matching passage encodes with question encodes
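# The tensors below are assumptions for illustration only (this gist does not define them);
# any encoder producing float32 encodes of shape B x T x D, float32 masks of shape B x T
# (1.0 for real tokens, 0.0 for padding) and int32 lengths of shape B works the same way.
p_encodes = tf.placeholder(tf.float32, [None, None, 32], name='p_encodes')  # passage encodes, B x T_p x D
q_encodes = tf.placeholder(tf.float32, [None, None, 32], name='q_encodes')  # question encodes, B x T_q x D
p_length = tf.placeholder(tf.int32, [None], name='p_length')                # true passage lengths, B
p_mask = tf.placeholder(tf.float32, [None, None], name='p_mask')            # passage mask, B x T_p
q_mask = tf.placeholder(tf.float32, [None, None], name='q_mask')            # question mask, B x T_q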
ml = MatchLSTMLayer(hidden_size=16,
                    control_gate=False,
                    pooling_window=5,  # pool the attended encodes; otherwise attention over a long passage can run out of memory
                    name='demo',
                    act_fn=tf.nn.relu,
                    attend_hidden_size=16)
match_outputs, match_state = ml.match(input_encodes=p_encodes,
                                      attended_encodes=q_encodes,
                                      input_length=p_length,
                                      input_mask=p_mask,
                                      attended_mask=q_mask)
#### Details
def attend_pooling(pooling_vectors, ref_vector, hidden_size, scope=None,
                   pooling_mask=None, activation_fn=tf.tanh, output_logit=False):
    """
    Applies attention pooling to a set of vectors according to a reference vector.
    Args:
        pooling_vectors: the vectors to pool, of size B x T x D
        ref_vector: the reference vector at a single time step, of size B x D'; D' may differ from D
        hidden_size: the hidden size of the attention function
        scope: scope name
        pooling_mask: optional B x T float mask, 1.0 for valid positions and 0.0 for padding
        activation_fn: activation applied to the attention hidden layer
        output_logit: if True, return the unnormalized logits instead of the softmax scores
    Returns:
        the pooled vector of size B x D and the attention scores (or logits) of size B x T x 1
    """
    with tf.variable_scope(scope or 'attend_pooling'):
        # pooling_vectors must be B x T x D
        assert pooling_vectors.get_shape().ndims == 3
        # ref_vector must be B x D', a single time step
        assert ref_vector.get_shape().ndims == 2
        if pooling_mask is not None:
            # pooling_mask must be B x T
            assert pooling_mask.get_shape().ndims == 2
        U = activation_fn(tc.layers.fully_connected(pooling_vectors,
                                                    num_outputs=hidden_size,
                                                    activation_fn=None)
                          + tf.expand_dims(tc.layers.fully_connected(ref_vector,
                                                                     num_outputs=hidden_size,
                                                                     activation_fn=None), axis=1))
        logits = tc.layers.fully_connected(U, num_outputs=1, activation_fn=None)
        if pooling_mask is not None:
            # push masked-out positions towards -inf so they receive ~0 attention weight
            logits -= tf.expand_dims(1.0 - pooling_mask, axis=2) * 1e30
        scores = tf.nn.softmax(logits, 1)
        pooled_vector = tf.reduce_sum(pooling_vectors * scores, axis=1)
        # pooled_vector is B x D, scores is B x T x 1
        return pooled_vector, logits if output_logit else scores
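# Stand-alone sketch of attend_pooling (illustrative only; these tensors are not part of the model above):
# pool T=10 candidate vectors of size D=24 per batch element against a single 24-d reference vector.
_vectors = tf.random_normal([4, 10, 24])                        # B x T x D
_ref = tf.random_normal([4, 24])                                # B x D
_mask = tf.sequence_mask([10, 7, 3, 10], 10, dtype=tf.float32)  # B x T, 1.0 = valid position
_pooled, _scores = attend_pooling(_vectors, _ref, hidden_size=8,
                                  pooling_mask=_mask, scope='attend_pooling_demo')
# _pooled has shape B x D; _scores has shape B x T x 1, sums to 1.0 over the T axis,
# and puts (near-)zero weight on the masked-out positions.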
class MatchLSTMAttnCell(tc.rnn.LSTMCell):
    """
    Implements the Match-LSTM attention cell
    """
    def __init__(self, num_units, context_to_attend, control_gate, pooling_window, attend_mask, act_fn,
                 attend_hidden_size):
        super().__init__(num_units, state_is_tuple=True)
        if pooling_window:
            self.context_to_attend = tf.nn.pool(context_to_attend,
                                                window_shape=[pooling_window],
                                                strides=[pooling_window],
                                                pooling_type='MAX', padding='SAME',
                                                name='max_pool_on_context')
            # apply the same max-pooling to the mask:
            # the mask is usually B x T, so expand it to B x T x 1 for pooling,
            # then squeeze it back to B x T
            self.attend_mask = tf.squeeze(tf.nn.pool(tf.expand_dims(attend_mask, -1),
                                                     window_shape=[pooling_window],
                                                     strides=[pooling_window],
                                                     pooling_type='MAX', padding='SAME',
                                                     name='max_pool_on_context_mask'), -1)
        else:
            self.context_to_attend = context_to_attend
            self.attend_mask = attend_mask
        self.control_gate = control_gate
        self.act_fn = act_fn
        self.attend_hidden_size = attend_hidden_size
    def __call__(self, inputs, state, scope=None):
        (c_prev, h_prev) = state
        with tf.variable_scope(scope or type(self).__name__):
            # attend to the (possibly pooled) context using the current input and the previous hidden state
            ref_vector = tf.concat([inputs, h_prev], -1)
            attended_context, scores = attend_pooling(self.context_to_attend,
                                                      ref_vector,
                                                      self.attend_hidden_size,
                                                      pooling_mask=self.attend_mask,
                                                      activation_fn=self.act_fn)
            new_inputs = tf.concat([inputs, attended_context,
                                    inputs - attended_context,
                                    inputs * attended_context],
                                   -1)
            if self.control_gate:
                # add an extra sigmoid gate on the augmented input; output_size * 8 matches the
                # width of new_inputs when the input and context encodes are both 2 * num_units wide
                control_gate = tc.layers.fully_connected(new_inputs,
                                                         num_outputs=self.output_size * 8,
                                                         activation_fn=tf.nn.sigmoid)
                new_inputs *= control_gate
        return super().__call__(new_inputs, state, scope)
class MatchLSTMLayer:
    """
    Implements the Match-LSTM layer, which attends to the question dynamically in an LSTM fashion.
    """
    def __init__(self, hidden_size: int, control_gate: bool,
                 pooling_window: int, name: str, act_fn, attend_hidden_size: int):
        self.hidden_size = hidden_size
        self.output_size = hidden_size * 2  # bi-directional
        self.act_fn = act_fn
        self.name = name
        self.control_gate = control_gate
        self.pooling_window = pooling_window  # useful when the context is too long
        self.attend_hidden_size = attend_hidden_size

    def match(self, input_encodes, attended_encodes, input_length, input_mask, attended_mask):
        """
        Matches the passage encodes with the question encodes using the Match-LSTM algorithm.
        Note: input_mask is currently unused; padded input steps are handled via sequence_length.
        """
        with tf.variable_scope(self.name):
            cell_fw = MatchLSTMAttnCell(self.hidden_size, attended_encodes, self.control_gate, self.pooling_window,
                                        attended_mask, self.act_fn, self.attend_hidden_size)
            cell_bw = MatchLSTMAttnCell(self.hidden_size, attended_encodes, self.control_gate, self.pooling_window,
                                        attended_mask, self.act_fn, self.attend_hidden_size)
            outputs, state = tf.nn.bidirectional_dynamic_rnn(cell_fw, cell_bw,
                                                             inputs=input_encodes,
                                                             sequence_length=input_length,
                                                             dtype=tf.float32)
            match_outputs = tf.concat(outputs, 2)
            state_fw, state_bw = state
            c_fw, h_fw = state_fw
            c_bw, h_bw = state_bw
            match_state = tf.concat([h_fw, h_bw], 1)
        return match_outputs, match_state
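#### Returned shapes and a session sketch (hedged: the numpy feeds below are illustrative only)
# match_outputs: B x T_p x (2 * hidden_size)  -- question-aware encoding of every passage position
# match_state:   B x (2 * hidden_size)        -- final forward and backward hidden states, concatenated
# e.g., with the placeholders and the `ml` layer from the usage section above:
#     import numpy as np
#     with tf.Session() as sess:
#         sess.run(tf.global_variables_initializer())
#         out, state = sess.run([match_outputs, match_state],
#                               feed_dict={p_encodes: np.random.rand(2, 30, 32).astype('float32'),
#                                          q_encodes: np.random.rand(2, 8, 32).astype('float32'),
#                                          p_length: [30, 30],
#                                          p_mask: np.ones([2, 30], 'float32'),
#                                          q_mask: np.ones([2, 8], 'float32')})
#     # out.shape == (2, 30, 32), state.shape == (2, 32)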