Attention Decoder (TF & Keras)
import tensorflow as tf
from keras import backend as K
from keras import regularizers, constraints, initializers, activations
from keras.layers.recurrent import Recurrent, _time_distributed_dense
from keras.engine import InputSpec

# debugging helper: prints a tensor's values and its shape while the graph runs
tfPrint = lambda d, T: tf.Print(input_=T, data=[T, tf.shape(T)], message=d)


class AttentionDecoder(Recurrent):

    def __init__(self, units, output_dim,
                 activation='tanh',
                 return_probabilities=False,
                 name='AttentionDecoder',
                 kernel_initializer='glorot_uniform',
                 recurrent_initializer='orthogonal',
                 bias_initializer='zeros',
                 kernel_regularizer=None,
                 bias_regularizer=None,
                 activity_regularizer=None,
                 kernel_constraint=None,
                 bias_constraint=None,
                 **kwargs):
""" | |
Implements an AttentionDecoder that takes in a sequence encoded by an | |
encoder and outputs the decoded states | |
:param units: dimension of the hidden state and the attention matrices | |
:param output_dim: the number of labels in the output space | |
references: | |
Bahdanau, Dzmitry, Kyunghyun Cho, and Yoshua Bengio. | |
"Neural machine translation by jointly learning to align and translate." | |
arXiv preprint arXiv:1409.0473 (2014). | |
""" | |
        self.units = units
        self.output_dim = output_dim
        self.return_probabilities = return_probabilities
        self.activation = activations.get(activation)
        self.kernel_initializer = initializers.get(kernel_initializer)
        self.recurrent_initializer = initializers.get(recurrent_initializer)
        self.bias_initializer = initializers.get(bias_initializer)
        self.kernel_regularizer = regularizers.get(kernel_regularizer)
        # note: the recurrent weights reuse the kernel regularizer and
        # constraint; the constructor does not expose separate recurrent ones
        self.recurrent_regularizer = regularizers.get(kernel_regularizer)
        self.bias_regularizer = regularizers.get(bias_regularizer)
        self.activity_regularizer = regularizers.get(activity_regularizer)
        self.kernel_constraint = constraints.get(kernel_constraint)
        self.recurrent_constraint = constraints.get(kernel_constraint)
        self.bias_constraint = constraints.get(bias_constraint)

        super(AttentionDecoder, self).__init__(**kwargs)
        self.name = name
        self.return_sequences = True  # must return sequences
def build(self, input_shape): | |
""" | |
See Appendix 2 of Bahdanau 2014, arXiv:1409.0473 | |
for model details that correspond to the matrices here. | |
""" | |
self.batch_size, self.timesteps, self.input_dim = input_shape | |
if self.stateful: | |
super(AttentionDecoder, self).reset_states() | |
self.states = [None, None] # y, s | |
""" | |
Matrices for creating the context vector | |
""" | |
self.V_a = self.add_weight(shape=(self.units,), | |
name='V_a', | |
initializer=self.kernel_initializer, | |
regularizer=self.kernel_regularizer, | |
constraint=self.kernel_constraint) | |
self.W_a = self.add_weight(shape=(self.units, self.units), | |
name='W_a', | |
initializer=self.kernel_initializer, | |
regularizer=self.kernel_regularizer, | |
constraint=self.kernel_constraint) | |
self.U_a = self.add_weight(shape=(self.input_dim, self.units), | |
name='U_a', | |
initializer=self.kernel_initializer, | |
regularizer=self.kernel_regularizer, | |
constraint=self.kernel_constraint) | |
self.b_a = self.add_weight(shape=(self.units,), | |
name='b_a', | |
initializer=self.bias_initializer, | |
regularizer=self.bias_regularizer, | |
constraint=self.bias_constraint) | |
""" | |
Matrices for the r (reset) gate | |
""" | |
self.C_r = self.add_weight(shape=(self.input_dim, self.units), | |
name='C_r', | |
initializer=self.recurrent_initializer, | |
regularizer=self.recurrent_regularizer, | |
constraint=self.recurrent_constraint) | |
self.U_r = self.add_weight(shape=(self.units, self.units), | |
name='U_r', | |
initializer=self.recurrent_initializer, | |
regularizer=self.recurrent_regularizer, | |
constraint=self.recurrent_constraint) | |
self.W_r = self.add_weight(shape=(self.output_dim, self.units), | |
name='W_r', | |
initializer=self.recurrent_initializer, | |
regularizer=self.recurrent_regularizer, | |
constraint=self.recurrent_constraint) | |
self.b_r = self.add_weight(shape=(self.units, ), | |
name='b_r', | |
initializer=self.bias_initializer, | |
regularizer=self.bias_regularizer, | |
constraint=self.bias_constraint) | |
""" | |
Matrices for the z (update) gate | |
""" | |
self.C_z = self.add_weight(shape=(self.input_dim, self.units), | |
name='C_z', | |
initializer=self.recurrent_initializer, | |
regularizer=self.recurrent_regularizer, | |
constraint=self.recurrent_constraint) | |
self.U_z = self.add_weight(shape=(self.units, self.units), | |
name='U_z', | |
initializer=self.recurrent_initializer, | |
regularizer=self.recurrent_regularizer, | |
constraint=self.recurrent_constraint) | |
self.W_z = self.add_weight(shape=(self.output_dim, self.units), | |
name='W_z', | |
initializer=self.recurrent_initializer, | |
regularizer=self.recurrent_regularizer, | |
constraint=self.recurrent_constraint) | |
self.b_z = self.add_weight(shape=(self.units, ), | |
name='b_z', | |
initializer=self.bias_initializer, | |
regularizer=self.bias_regularizer, | |
constraint=self.bias_constraint) | |
""" | |
Matrices for the proposal | |
""" | |
self.C_p = self.add_weight(shape=(self.input_dim, self.units), | |
name='C_p', | |
initializer=self.recurrent_initializer, | |
regularizer=self.recurrent_regularizer, | |
constraint=self.recurrent_constraint) | |
self.U_p = self.add_weight(shape=(self.units, self.units), | |
name='U_p', | |
initializer=self.recurrent_initializer, | |
regularizer=self.recurrent_regularizer, | |
constraint=self.recurrent_constraint) | |
self.W_p = self.add_weight(shape=(self.output_dim, self.units), | |
name='W_p', | |
initializer=self.recurrent_initializer, | |
regularizer=self.recurrent_regularizer, | |
constraint=self.recurrent_constraint) | |
self.b_p = self.add_weight(shape=(self.units, ), | |
name='b_p', | |
initializer=self.bias_initializer, | |
regularizer=self.bias_regularizer, | |
constraint=self.bias_constraint) | |
""" | |
Matrices for making the final prediction vector | |
""" | |
self.C_o = self.add_weight(shape=(self.input_dim, self.output_dim), | |
name='C_o', | |
initializer=self.recurrent_initializer, | |
regularizer=self.recurrent_regularizer, | |
constraint=self.recurrent_constraint) | |
self.U_o = self.add_weight(shape=(self.units, self.output_dim), | |
name='U_o', | |
initializer=self.recurrent_initializer, | |
regularizer=self.recurrent_regularizer, | |
constraint=self.recurrent_constraint) | |
self.W_o = self.add_weight(shape=(self.output_dim, self.output_dim), | |
name='W_o', | |
initializer=self.recurrent_initializer, | |
regularizer=self.recurrent_regularizer, | |
constraint=self.recurrent_constraint) | |
self.b_o = self.add_weight(shape=(self.output_dim, ), | |
name='b_o', | |
initializer=self.bias_initializer, | |
regularizer=self.bias_regularizer, | |
constraint=self.bias_constraint) | |
# For creating the initial state: | |
self.W_s = self.add_weight(shape=(self.input_dim, self.units), | |
name='W_s', | |
initializer=self.recurrent_initializer, | |
regularizer=self.recurrent_regularizer, | |
constraint=self.recurrent_constraint) | |
self.input_spec = [ | |
InputSpec(shape=(self.batch_size, self.timesteps, self.input_dim))] | |
self.built = True | |

    def call(self, x):
        # store the whole sequence so we can "attend" to it at each timestep
        self.x_seq = x

        # apply a dense layer over the time dimension of the sequence;
        # doing it here rather than inside step() saves computation,
        # since the result does not depend on any previous decoding step
        self._uxpb = _time_distributed_dense(self.x_seq, self.U_a, b=self.b_a,
                                             input_dim=self.input_dim,
                                             timesteps=self.timesteps,
                                             output_dim=self.units)
        return super(AttentionDecoder, self).call(x)

    def get_initial_state(self, inputs):
        # apply W_s to the first timestep of the input to get the initial s0
        s0 = activations.tanh(K.dot(inputs[:, 0], self.W_s))

        # build a zero tensor of shape (samples, output_dim) for y0; the
        # sum/expand/tile dance sidesteps the statically unknown batch size
        # (same trick as in keras.layers.recurrent)
        y0 = K.zeros_like(inputs)              # (samples, timesteps, input_dims)
        y0 = K.sum(y0, axis=(1, 2))            # (samples, )
        y0 = K.expand_dims(y0)                 # (samples, 1)
        y0 = K.tile(y0, [1, self.output_dim])  # (samples, output_dim)

        return [y0, s0]

    def step(self, x, states):
        ytm, stm = states

        # repeat the hidden state across the length of the input sequence
        _stm = K.repeat(stm, self.timesteps)

        # multiply the attention weight matrix with the repeated hidden state
        _Wxstm = K.dot(_stm, self.W_a)

        # calculate the attention energies; these measure how much each
        # encoder timestep should contribute to the current decoding step.
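        # per the alignment model of Bahdanau et al. (2014):
        #   e_{tj} = V_a^T tanh(W_a s_{t-1} + U_a h_j + b_a)
        # (self._uxpb already holds U_a h_j + b_a for every timestep j)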
        et = K.dot(activations.tanh(_Wxstm + self._uxpb),
                   K.expand_dims(self.V_a))
        # softmax over the time axis, written out by hand:
        # exponentiate, then normalize by the per-sample sum
        at = K.exp(et)
        at_sum = K.sum(at, axis=1)
        at_sum_repeated = K.repeat(at_sum, self.timesteps)
        at /= at_sum_repeated  # attention weights, shape (batchsize, timesteps, 1)
        # calculate the context vector
        context = K.squeeze(K.batch_dot(at, self.x_seq, axes=1), axis=1)
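        # i.e. c_t = sum_j alpha_{tj} h_j, the attention-weighted sum of the
        # encoder states (Eq. (5) in Bahdanau et al., 2014)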
        # ~~~> calculate the new hidden state
        # first calculate the "r" (reset) gate:
        rt = activations.sigmoid(
            K.dot(ytm, self.W_r)
            + K.dot(stm, self.U_r)
            + K.dot(context, self.C_r)
            + self.b_r)

        # now calculate the "z" (update) gate:
        zt = activations.sigmoid(
            K.dot(ytm, self.W_z)
            + K.dot(stm, self.U_z)
            + K.dot(context, self.C_z)
            + self.b_z)

        # calculate the proposal hidden state:
        s_tp = activations.tanh(
            K.dot(ytm, self.W_p)
            + K.dot((rt * stm), self.U_p)
            + K.dot(context, self.C_p)
            + self.b_p)

        # new hidden state:
        st = (1 - zt) * stm + zt * s_tp
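        # note: Bahdanau et al. compute the output through a deep-output
        # layer with maxout units; this implementation substitutes a single
        # softmax projection over the previous label, state, and context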
        yt = activations.softmax(
            K.dot(ytm, self.W_o)
            + K.dot(stm, self.U_o)
            + K.dot(context, self.C_o)
            + self.b_o)

        if self.return_probabilities:
            return at, [yt, st]
        else:
            return yt, [yt, st]

    def compute_output_shape(self, input_shape):
        """
        For Keras internal compatibility checking.
        """
        if self.return_probabilities:
            return (None, self.timesteps, self.timesteps)
        else:
            return (None, self.timesteps, self.output_dim)

    def get_config(self):
        """
        For rebuilding models at load time.
        """
        config = {
            'output_dim': self.output_dim,
            'units': self.units,
            'return_probabilities': self.return_probabilities
        }
        base_config = super(AttentionDecoder, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
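

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original layer code): wires the
# decoder to a bidirectional LSTM encoder. The sequence length (20), input
# dimension (50), and unit counts below are illustrative assumptions only.
if __name__ == '__main__':
    from keras.models import Sequential
    from keras.layers import LSTM, Bidirectional

    model = Sequential()
    # encoder: return the full sequence of hidden states so the decoder
    # has something to attend over
    model.add(Bidirectional(LSTM(150, return_sequences=True),
                            input_shape=(20, 50)))
    # decoder: 150 hidden/attention units, 50 output labels per timestep
    model.add(AttentionDecoder(150, 50))
    model.compile(loss='categorical_crossentropy', optimizer='adam')
    model.summary()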