@odanado
Created December 3, 2017 07:17
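import numpy as np


def make_mask(lens, mask_shape):
    # Start with every position marked as padding (1).
    mask = np.ones(mask_shape, dtype='uint8')
    for i, l in enumerate(lens):
        # The first l positions of sequence i are real tokens (0).
        mask[i, :l] = 0
    return mask


batch_size = 2
seq_len = 4
lens = np.array([2, 3])  # actual length of each sequence in the batch
shape = (batch_size, seq_len)
mask = make_mask(lens, shape)
print(mask)
# [[0 0 1 1]
#  [0 0 0 1]]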
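The per-row loop in make_mask can also be expressed with NumPy broadcasting, which avoids Python-level iteration for large batches. This is a minimal sketch, not part of the original gist: the name make_mask_vectorized is hypothetical, and it assumes mask_shape[0] == len(lens) (the loop version would leave any extra rows as all ones). It keeps the same convention: 0 for real tokens, 1 for padding.

```python
import numpy as np


def make_mask_vectorized(lens, mask_shape):
    # Hypothetical broadcasting variant of make_mask.
    # Assumes mask_shape == (len(lens), seq_len).
    _, seq_len = mask_shape
    lens = np.asarray(lens)
    # positions: shape (1, seq_len); lens[:, None]: shape (batch_size, 1).
    # Broadcasting compares every position index with every sequence length;
    # positions at or beyond a sequence's length are padding (1).
    positions = np.arange(seq_len)[None, :]
    return (positions >= lens[:, None]).astype('uint8')


mask = make_mask_vectorized(np.array([2, 3]), (2, 4))
print(mask)
# [[0 0 1 1]
#  [0 0 0 1]]
```

As with the slicing in the loop version, a length greater than seq_len simply yields a row of all zeros.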