-
-
Save braingineer/27c6f26755794f6544d83dec2dd27bbb to your computer and use it in GitHub Desktop.
keras soft attention
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def make_safe(x):
    """Clip ``x`` element-wise into the interval ``[eps, 1 - eps]``.

    ``eps`` is the backend fuzz factor.  It is read through the public
    ``K.epsilon()`` accessor rather than the private ``K.common._EPSILON``
    attribute, so the helper does not depend on the backend package's
    internal module layout.

    Args:
        x: backend tensor to clip.

    Returns:
        A tensor with the same shape as ``x`` whose values lie in
        ``[eps, 1 - eps]``.
    """
    eps = K.epsilon()
    return K.clip(x, eps, 1.0 - eps)
class ProbabilityTensor(Wrapper):
    """Turn a 3D tensor into a 2D matrix of per-timestep probabilities.

    The wrapped layer is applied to the 3D input, its output is squeezed on
    axis 2 and a softmax is taken, so an input of shape ``(b, n, f)`` yields
    an output of shape ``(b, n)``.  When a mask is supplied, masked positions
    are zeroed and the remaining weights are renormalized along the last axis.

    Args:
        dense_function: layer to be wrapped in ``Distribute``.  Defaults to
            ``Dense(1, name='ptensor_func')``.  NOTE(review): a custom layer
            presumably has to emit a single unit per timestep, because
            ``call`` squeezes axis 2 of its output -- confirm with callers.
        *args, **kwargs: forwarded unchanged to the ``Wrapper`` constructor.
    """

    def __init__(self, dense_function=None, *args, **kwargs):
        # These attributes are assigned before the base constructor runs,
        # exactly as in the original implementation.
        self.supports_masking = True
        self.input_spec = [InputSpec(ndim=3)]
        # BIG NOTE: "Distribute" is this file's name for TimeDistributed.
        # The default layer has to be chosen *before* wrapping: the previous
        # ``Distribute(dense_function) or Distribute(Dense(...))`` built
        # ``Distribute(None)`` first, so the fallback was never reached.
        if dense_function is None:
            dense_function = Dense(1, name='ptensor_func')
        layer = Distribute(dense_function)
        super(ProbabilityTensor, self).__init__(layer, *args, **kwargs)

    def build(self, input_shape):
        """Pin the input spec to ``input_shape`` and build the wrapped layer.

        Raises:
            AssertionError: if ``input_shape`` does not have three entries.
            Exception: on the TensorFlow backend when the timestep dimension
                (``input_shape[1]``) is undefined.
        """
        assert len(input_shape) == 3
        self.input_spec = [InputSpec(shape=input_shape)]
        if K._BACKEND == 'tensorflow' and not input_shape[1]:
            raise Exception('When using TensorFlow, you should define '
                            'explicitly the number of timesteps of '
                            'your sequences.\n'
                            'If your first layer is an Embedding, '
                            'make sure to pass it an "input_length" '
                            'argument. Otherwise, make sure '
                            'the first layer has '
                            'an "input_shape" or "batch_input_shape" '
                            'argument, including the time axis.')
        if not self.layer.built:
            self.layer.build(input_shape)
            self.layer.built = True
        super(ProbabilityTensor, self).build()

    def get_output_shape_for(self, input_shape):
        """Map ``(b, n, f)`` to ``(b, n)``; each output row sums to 1 over n.

        A list/tuple whose first element is itself a list/tuple is treated
        as a collection of shapes and only the first shape is used.
        """
        # The previous test, ``not isinstance(input_shape[0], int)``, was
        # also true for an unknown batch size (``None``); a plain shape such
        # as ``(None, n, f)`` was then "unwrapped" to ``None`` and the
        # indexing below raised TypeError.
        if (isinstance(input_shape, (list, tuple))
                and isinstance(input_shape[0], (list, tuple))):
            input_shape = input_shape[0]
        return (input_shape[0], input_shape[1])

    def squash_mask(self, mask):
        """Reduce ``mask`` to rank 2.

        A rank-2 mask is returned unchanged; a rank-3 mask is collapsed with
        ``K.any`` over its last axis.  Any other rank yields ``None``, as in
        the original implementation.
        """
        rank = K.ndim(mask)
        if rank == 2:
            return mask
        if rank == 3:
            return K.any(mask, axis=-1)
        return None

    def compute_mask(self, x, mask=None):
        """Return the rank-2 version of ``mask``, or ``None`` if no mask."""
        if mask is None:
            return None
        return self.squash_mask(mask)

    def call(self, x, mask=None):
        """Compute the ``(b, n)`` probability matrix for the 3D input ``x``."""
        energy = K.squeeze(self.layer(x), 2)
        p_matrix = softmax(energy)
        if mask is not None:
            # Masks are boolean/integer tensors; cast explicitly so the
            # arithmetic below also works on backends that do not coerce
            # dtypes implicitly.
            mask = K.cast(self.squash_mask(mask), K.floatx())
            # Zero the masked steps, clip away from exact 0/1 so the
            # normalizing sum is strictly positive, renormalize, and then
            # zero the masked steps again (clipping lifted them to eps).
            p_matrix = make_safe(p_matrix * mask)
            p_matrix = (p_matrix / K.sum(p_matrix, axis=-1, keepdims=True)) * mask
        return p_matrix

    def get_config(self):
        """Return a copy of the base ``Wrapper`` configuration.

        This layer adds no configuration keys of its own.
        """
        base_config = super(ProbabilityTensor, self).get_config()
        return dict(base_config.items())
class SoftAttention(ProbabilityTensor):
    """Attention-weighted sum of a 3D input over its time axis.

    The ``(b, n)`` probability matrix produced by ``ProbabilityTensor.call``
    is used to weight the ``(b, n, f)`` input, which is then summed over
    axis 1 to give a ``(b, f)`` output.
    """

    def get_output_shape_for(self, input_shape):
        """Map ``(b, n, f)`` to ``(b, f)``: features summed across n."""
        return (input_shape[0], input_shape[2])

    def compute_mask(self, x, mask=None):
        """Return ``None``: no mask is propagated past this layer.

        Raises:
            Exception: if a mask is given whose rank is not 2.
        """
        # ``K.ndim`` is used instead of ``mask.ndim`` so the check does not
        # rely on a tensor attribute that only some backends provide.
        if mask is None or K.ndim(mask) == 2:
            return None
        raise Exception("Unexpected situation")

    def call(self, x, mask=None):
        """Return the attention-weighted sum of ``x`` over axis 1."""
        # (b, n) -> (b, n, 1).  The trailing singleton axis broadcasts
        # against the feature axis of ``x``, so the weights do not have to
        # be tiled explicitly.  The previous ``K.repeat_elements`` call was
        # given a symbolic repeat count (``K.shape(x)[2]``), which is not
        # usable on backends that need a Python integer there.
        p_vectors = K.expand_dims(super(SoftAttention, self).call(x, mask), 2)
        return K.sum(p_vectors * x, axis=1)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment