@MSWon
Created May 7, 2019 07:30
TensorFlow self-attention mask with zeros on the diagonal
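import tensorflow as tf  # written against the TensorFlow 1.x graph API

max_seq_len = 6
seq_len = [3, 5, 4]  # true (unpadded) length of each sequence in the batch

row_vector = tf.range(0, max_seq_len, 1)  ## [max_seq_len]
matrix = tf.cast(tf.expand_dims(seq_len, -1), tf.int32)  ## [batch_size, 1]
# Broadcast compare: 1.0 where the position lies inside the sequence, else 0.0
t = tf.cast(row_vector < matrix, tf.float32)  ## [batch_size, max_seq_len]
t = tf.expand_dims(t, -1)  ## [batch_size, max_seq_len, 1]
# Outer product: entry (i, j) is 1.0 only when both i and j are real tokens
masks = t * tf.transpose(t, [0, 2, 1])  ## [batch_size, max_seq_len, max_seq_len]
# Zero the diagonal, i.e. every entry that pairs a position with itself
new_masks = tf.linalg.set_diag(masks, tf.zeros(masks.shape[0:-1]))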
# The graph creates no variables, so no initializer needs to run
with tf.Session() as sess:
    print("masks : ")
    print(sess.run(masks))
    print("new masks : ")
    print(sess.run(new_masks))
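
The same mask can be checked without a TensorFlow session. The block below is a minimal NumPy sketch of the steps above, not part of the original gist; the helper name `self_attention_mask` is hypothetical, and it assumes `seq_len` is a plain list of integer lengths no larger than `max_seq_len`.

```python
import numpy as np


def self_attention_mask(seq_len, max_seq_len):
    """Hypothetical helper: NumPy mirror of the TensorFlow steps above."""
    positions = np.arange(max_seq_len)                       # [max_seq_len]
    lengths = np.asarray(seq_len, dtype=np.int32)[:, None]   # [batch_size, 1]
    # 1.0 where the position lies inside the sequence, else 0.0
    valid = (positions < lengths).astype(np.float32)         # [batch_size, max_seq_len]
    # Outer product per batch element
    masks = valid[:, :, None] * valid[:, None, :]            # [batch_size, max_seq_len, max_seq_len]
    # Copy, then zero the diagonal of every matrix in the batch
    new_masks = masks.copy()
    idx = np.arange(max_seq_len)
    new_masks[:, idx, idx] = 0.0
    return masks, new_masks


masks, new_masks = self_attention_mask([3, 5, 4], 6)
print("masks : ")
print(masks)
print("new masks : ")
print(new_masks)
```

For the first sequence (length 3), `masks[0]` holds a 3x3 block of ones in its top-left corner and zeros elsewhere, and `new_masks[0]` is the same block with its three diagonal entries set to zero.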