@MSWon
Last active March 9, 2019 14:31
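Padding with max seq len: zero out every time step at or beyond each sequence's true length (seq_len), so a batch padded to max_seq_len keeps only its real steps.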
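import numpy as np
import tensorflow as tf  # TensorFlow 1.x graph API (tf.placeholder / tf.Session)

batch_size = 3
max_seq_len = 5
dim = 4

# Padded input batch of shape (batch, max_seq_len, dim) and the true length of each sequence
input_data = tf.placeholder(shape=(None, max_seq_len, dim), dtype=tf.float32)
seq_len = tf.placeholder(shape=[None], dtype=tf.int32)

# Time-step index grid of shape (1, max_seq_len, dim): entry [0, t, :] equals t.
# dtype=np.int32 so the comparison with the int32 seq_len below does not fail on a dtype mismatch.
max_len_range = np.array([range(max_seq_len)] * dim, dtype=np.int32).T
max_len_range = np.reshape(max_len_range, (1, max_len_range.shape[0], max_len_range.shape[1]))

# Repeat the grid for every example; use the runtime batch size instead of the fixed batch_size
new_a = tf.tile(max_len_range, [tf.shape(input_data)[0], 1, 1])

# Lengths reshaped to (batch, 1, 1) so they broadcast against the (batch, max_seq_len, dim) grid
seq_len_ = tf.reshape(seq_len, (-1, 1, 1))

# 1.0 where t < seq_len (real step), 0.0 at padded steps
max_len_mask = new_a < seq_len_
max_len_mask = tf.cast(max_len_mask, dtype=tf.float32)

# Zero out every padded time step
pad_output = tf.multiply(input_data, max_len_mask)

sess = tf.Session()
sample_data = np.random.rand(batch_size, max_seq_len, dim)
sample_seq_len = [3, 4, 2]
print(sess.run(pad_output, feed_dict={input_data: sample_data, seq_len: sample_seq_len}))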
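The same masking can be sketched in plain NumPy, which makes the shapes easy to inspect without building a graph. This is a minimal sketch, not the gist's own code: `mask_padding` is a hypothetical helper name, and it assumes a batch shaped `(batch_size, max_seq_len, dim)` with one length per sequence. Instead of tiling an index grid, it compares a `(1, max_seq_len)` row of time-step indices against a `(batch_size, 1)` column of lengths and lets broadcasting build the mask.

```python
import numpy as np


def mask_padding(batch, lengths):
    """Hypothetical helper: zero out every time step at or beyond each sequence's length.

    batch:   array of shape (batch_size, max_seq_len, dim)
    lengths: sequence of shape (batch_size,) with the true length of each sequence
    """
    max_seq_len = batch.shape[1]
    # (1, max_seq_len) indices vs (batch_size, 1) lengths -> (batch_size, max_seq_len) bool mask
    mask = np.arange(max_seq_len)[None, :] < np.asarray(lengths)[:, None]
    # Trailing axis lets the mask broadcast over the feature dimension
    return batch * mask[:, :, None].astype(batch.dtype)


# Usage with all-ones data so the effect of the mask is easy to read
ones = np.ones((3, 5, 4))
out = mask_padding(ones, [3, 4, 2])
# Summing over dim gives 4.0 for each real step and 0.0 for each padded step
print(out.sum(axis=2))
```

Because broadcasting replaces the tiling step, nothing in the sketch depends on a fixed batch size. In TensorFlow, `tf.sequence_mask(seq_len, maxlen=max_seq_len, dtype=tf.float32)` should produce the same `(batch, max_seq_len)` mask, which can then be expanded with `[:, :, None]` before the multiply.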