Instantly share code, notes, and snippets.

What would you like to do?
Keras soft attention
def make_safe(x):
    """Clip ``x`` element-wise into ``[epsilon, 1 - epsilon]``.

    Keeps values strictly away from 0 and 1 so that a subsequent division by
    a sum of these values stays finite. ``epsilon`` is the backend fuzz
    factor, read through the public ``K.epsilon()`` accessor rather than the
    private ``K.common._EPSILON`` global (same value, stable API).
    """
    eps = K.epsilon()
    return K.clip(x, eps, 1.0 - eps)
class ProbabilityTensor(Wrapper):
    """Turn a 3D tensor into a 2D probability matrix.

    Maps an input of shape ``(batch, n, features)`` to ``(batch, n)``. A
    time-distributed scoring layer assigns one energy value to each of the
    ``n`` timesteps, and a softmax over those energies yields one
    probability per timestep. When a mask is supplied, masked timesteps get
    probability 0 and the remaining ones are renormalised.

    Args:
        dense_function: layer applied to every timestep to produce the
            energy. It must emit exactly one unit per timestep, because the
            trailing axis is squeezed away in ``call``. Defaults to
            ``Dense(1, name='ptensor_func')``.
        *args, **kwargs: forwarded to ``Wrapper.__init__``.
    """

    def __init__(self, dense_function=None, *args, **kwargs):
        # Set before the base constructor runs. This presumably relies on the
        # Keras base Layer only defaulting these when they are absent
        # (Keras 1.x convention) -- TODO confirm for the Keras version in use.
        self.supports_masking = True
        self.input_spec = [InputSpec(ndim=3)]
        # NOTE: "Distribute" is this file's name for TimeDistributed.
        # Choose the default *before* wrapping. A wrapped layer instance is
        # always truthy, so an `or` fallback on the wrapped value never fires.
        if dense_function is None:
            dense_function = Dense(1, name='ptensor_func')
        layer = Distribute(dense_function)
        super(ProbabilityTensor, self).__init__(layer, *args, **kwargs)

    def build(self, input_shape):
        """Build the wrapped time-distributed scoring layer.

        Args:
            input_shape: 3-element shape ``(batch, n, features)``.

        Raises:
            Exception: under the TensorFlow backend when the number of
                timesteps (``input_shape[1]``) is not known.
        """
        assert len(input_shape) == 3
        self.input_spec = [InputSpec(shape=input_shape)]
        if K.backend() == 'tensorflow':
            if not input_shape[1]:
                raise Exception('When using TensorFlow, you should define '
                                'explicitly the number of timesteps of '
                                'your sequences.\n'
                                'If your first layer is an Embedding, '
                                'make sure to pass it an "input_length" '
                                'argument. Otherwise, make sure '
                                'the first layer has '
                                'an "input_shape" or "batch_input_shape" '
                                'argument, including the time axis.')
        if not self.layer.built:
            # Build the inner layer for real so its weights exist before the
            # Wrapper base build collects them. Only setting the flag would
            # leave it without weights and stop __call__ from building it later.
            self.layer.build(input_shape)
            self.layer.built = True
        super(ProbabilityTensor, self).build()

    def get_output_shape_for(self, input_shape):
        """Return ``(batch, n)`` for an input of shape ``(batch, n, features)``.

        Each output row holds one probability per timestep, so
        ``sum over n == 1``.
        """
        # Only unwrap when given a list of shapes (one per input), i.e. when
        # the first element is itself a shape. Testing "first element is not
        # an int" misfired on the usual `None` batch dimension: it replaced
        # the shape with None and crashed below.
        if input_shape and isinstance(input_shape[0], (list, tuple)):
            input_shape = input_shape[0]
        return (input_shape[0], input_shape[1])

    def squash_mask(self, mask):
        """Reduce ``mask`` to shape ``(batch, n)``.

        A 2D mask is returned unchanged. A 3D mask is collapsed over its last
        axis with ``K.any``, so a timestep stays unmasked if any of its
        entries is set. Masks of any other rank yield ``None``.
        """
        rank = K.ndim(mask)
        if rank == 2:
            return mask
        if rank == 3:
            return K.any(mask, axis=-1)
        return None

    def compute_mask(self, x, mask=None):
        """Propagate the input mask, squashed to ``(batch, n)``."""
        if mask is None:
            return None
        return self.squash_mask(mask)

    def call(self, x, mask=None):
        """Return the ``(batch, n)`` probability matrix for ``x``."""
        # (b, n, f) -> (b, n, 1) energies -> (b, n)
        energy = K.squeeze(self.layer(x), 2)
        p_matrix = softmax(energy)
        if mask is not None:
            # Cast so the mask can be multiplied with float tensors; a boolean
            # mask (as TensorFlow produces) cannot be multiplied directly.
            mask = K.cast(self.squash_mask(mask), K.floatx())
            # Zero the masked positions, then clip (make_safe) so the row sum
            # used for renormalising is never zero.
            p_matrix = make_safe(p_matrix * mask)
            # Renormalise over n, then re-zero masked positions, which the
            # clip raised to epsilon.
            p_matrix = (p_matrix / K.sum(p_matrix, axis=-1, keepdims=True)) * mask
        return p_matrix

    def get_config(self):
        """Return the layer config; this class adds no keys beyond the base."""
        config = {}
        base_config = super(ProbabilityTensor, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
class SoftAttention(ProbabilityTensor):
    """Soft attention pooling over the timestep axis.

    Uses the per-timestep probabilities computed by ``ProbabilityTensor`` as
    weights and returns the weighted sum of the input over the ``n`` axis.
    This maps ``(batch, n, features)`` to ``(batch, features)``.
    """

    def get_output_shape_for(self, input_shape):
        # b,n,f -> b,f where f is weighted features summed across n
        return (input_shape[0], input_shape[2])

    def compute_mask(self, x, mask=None):
        """Return ``None``: the output has no timestep axis left to mask.

        Raises:
            Exception: if ``mask`` has a rank other than 2 or 3.
        """
        # Use K.ndim rather than the tensor's own `.ndim` attribute, which is
        # not available on every backend's tensor type. 3D masks are accepted
        # because the parent `call` squashes them to 2D.
        if mask is None or K.ndim(mask) in (2, 3):
            return None
        raise Exception("Unexpected situation")

    def call(self, x, mask=None):
        """Return the attention-weighted sum of ``x`` over its ``n`` axis."""
        # (b, n) probabilities -> (b, n, 1)
        p_vectors = K.expand_dims(super(SoftAttention, self).call(x, mask), 2)
        # Broadcasting (b, n, 1) * (b, n, f) weights every feature. This avoids
        # K.repeat_elements, which on some backends needs a concrete (not
        # symbolic) repeat count such as K.shape(x)[2].
        return K.sum(p_vectors * x, axis=1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment