
@koreyou
Created August 26, 2017 09:48
[Chainer] Mask sequence with length
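import chainer
import chainer.functions as F
from chainer import cuda

def sequence_mask(x, length, value=0.):
    # Keep the first length[...] entries along axis length.ndim of x; fill the rest with value.
    if isinstance(length, chainer.Variable):
        length = length.data
    xp = cuda.get_array_module(length)
    # Position index along axis length.ndim, reshaped to the same rank as x.
    positions = xp.arange(x.shape[length.ndim]).reshape(
        (1,) * length.ndim + (-1,) + (1,) * (x.ndim - length.ndim - 1))
    # Append trailing singleton axes so each length broadcasts over its own sequence.
    length = length.reshape(length.shape + (1,) * (x.ndim - length.ndim))
    pad = xp.full(x.shape, value, dtype=x.dtype)
    mask = xp.broadcast_to(positions, x.shape) < length
    return F.where(mask, x, pad)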
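To see what the reshape and broadcast steps produce, here is a NumPy-only sketch of the mask construction, with no Chainer involved. `build_mask` is a hypothetical helper introduced only for illustration; like `sequence_mask`, it assumes the sequence axis of `x` sits immediately after the axes of `length`.

```python
import numpy as np


def build_mask(x_shape, length):
    """Boolean mask of shape x_shape that is True where the position along
    axis length.ndim is smaller than the matching entry of length."""
    length = np.asarray(length)
    n = length.ndim
    # Position index 0..T-1 on the sequence axis, singleton axes elsewhere.
    positions = np.arange(x_shape[n]).reshape(
        (1,) * n + (-1,) + (1,) * (len(x_shape) - n - 1))
    # Trailing singleton axes let each length broadcast over its own row.
    length = length.reshape(length.shape + (1,) * (len(x_shape) - n))
    return np.broadcast_to(positions, x_shape) < length


x = np.ones((2, 4, 3), dtype=np.float32)  # (batch, time, features)
mask = build_mask(x.shape, [1, 3])
masked = np.where(mask, x, 0.0)
print(masked[:, :, 0])
# [[1. 0. 0. 0.]
#  [1. 1. 1. 0.]]
```

A multi-axis `length` (for example shape `(2, 2)` against `x` of shape `(2, 2, 5)`) works the same way, because the reshape appends singleton axes to `length` instead of flattening it into a single axis.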