@Chiang97912
Created August 1, 2019 13:12
attention
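```python
import tensorflow as tf
from tensorflow.keras.layers import Dense, Lambda, Multiply, Permute, RepeatVector, Reshape

# Globals read by attention_3d_block; these values are placeholders (assumptions), set them for your model.
TIME_STEPS = 20                  # length of the time axis of `inputs`
SINGLE_ATTENTION_VECTOR = False  # True: share one attention vector across all input features


def attention_3d_block(inputs):
    # inputs.shape = (batch_size, time_steps, input_dim)
    input_dim = int(inputs.shape[2])
    a = Permute((2, 1))(inputs)  # -> (batch_size, input_dim, time_steps)
    a = Reshape((input_dim, TIME_STEPS))(a)  # no-op; it only documents which dimension is which
    a = Dense(TIME_STEPS, activation='softmax')(a)  # softmax over the time axis, separately per feature
    if SINGLE_ATTENTION_VECTOR:
        # average the per-feature attention vectors, then copy the result back to every feature
        a = Lambda(lambda x: tf.reduce_mean(x, axis=1), name='dim_reduction')(a)
        a = RepeatVector(input_dim)(a)
    a_probs = Permute((2, 1), name='attention_vec')(a)  # -> (batch_size, time_steps, input_dim)
    # Multiply replaces the Keras 1 `merge(..., mode='mul')` call, which later Keras versions removed
    output_attention_mul = Multiply(name='attention_mul')([inputs, a_probs])
    return output_attention_mul
```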
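The axis shuffling in `attention_3d_block` is easy to lose track of, so here is a minimal NumPy sketch of the same forward pass. It assumes the Dense layer's kernel and bias are passed in explicitly as `W` (shape `(time_steps, time_steps)`) and `b`; the names `attention_3d_reference`, `softmax`, `W` and `b` are hypothetical, introduced only for illustration, and are not part of the gist or of Keras. It shows that each feature gets its own softmax distribution over the time steps, and that the output is just the input scaled element-wise by those weights.

```python
import numpy as np


def softmax(z, axis=-1):
    # Numerically stable softmax along `axis`.
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def attention_3d_reference(inputs, W, b, single_attention_vector=False):
    # Hypothetical NumPy mirror of attention_3d_block's forward pass.
    # inputs: (batch, time_steps, input_dim); W: (time_steps, time_steps); b: (time_steps,)
    # W and b stand in for the Dense layer's kernel and bias (assumed, not learned here).
    a = np.transpose(inputs, (0, 2, 1))  # Permute((2, 1)): (batch, input_dim, time_steps)
    a = softmax(a @ W + b, axis=-1)      # Dense(TIME_STEPS, softmax) acts on the last (time) axis
    if single_attention_vector:
        a = a.mean(axis=1)                                     # mean over features: (batch, time_steps)
        a = np.repeat(a[:, None, :], inputs.shape[2], axis=1)  # RepeatVector(input_dim)
    attention = np.transpose(a, (0, 2, 1))  # Permute back: (batch, time_steps, input_dim)
    return inputs * attention, attention    # element-wise product, like the Multiply layer


rng = np.random.default_rng(0)
batch, time_steps, input_dim = 2, 5, 3
x = rng.normal(size=(batch, time_steps, input_dim))
W = rng.normal(size=(time_steps, time_steps))
b = np.zeros(time_steps)

out, att = attention_3d_reference(x, W, b)
print(out.shape, att.shape)               # (2, 5, 3) (2, 5, 3)
print(np.allclose(att.sum(axis=1), 1.0))  # True: weights over time sum to 1 per feature
```

With `single_attention_vector=True` every feature column of `att` becomes identical: the `SINGLE_ATTENTION_VECTOR` switch trades per-feature weights for one shared distribution over time steps.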