I do this so many times, might as well make a gist.
Start with an example tensor of shape 4 (batch size) x 2 (sequence length) x 3 (embedding dim):
import torch
a = torch.tensor([[[1,2,3],[4,5,6]],[[10,11,12],[13,14,15]],[[6,7,8],[9,10,11]],[[12,13,14],[15,16,17]]])
tensor([[[ 1, 2, 3],
[ 4, 5, 6]],
[[10, 11, 12],
[13, 14, 15]],
[[ 6, 7, 8],
[ 9, 10, 11]],
[[12, 13, 14],
[15, 16, 17]]])
Say that for each batch element we want to keep only the first k entries of the final dim and zero out the rest, with k = 0, 1, 2, 3 respectively: c = torch.tensor([0, 1, 2, 3]).unsqueeze(-1)
tensor([[0],
[1],
[2],
[3]])
Next we need a range over the final dim: b = torch.arange(a.size(-1))
tensor([0, 1, 2])
Broadcasting compares the (4, 1) column c against the length-3 row b, giving a (4, 3) mask: d = b < c
tensor([[False, False, False],
[ True, False, False],
[ True, True, False],
[ True, True, True]])
Unsqueeze and repeat the mask across the sequence dim: e = d.unsqueeze(1).repeat(1, a.size(1), 1)
tensor([[[False, False, False],
[False, False, False]],
[[ True, False, False],
[ True, False, False]],
[[ True, True, False],
[ True, True, False]],
[[ True, True, True],
[ True, True, True]]])
Multiplying applies the mask: a * e
tensor([[[ 0, 0, 0],
[ 0, 0, 0]],
[[10, 0, 0],
[13, 0, 0]],
[[ 6, 7, 0],
[ 9, 10, 0]],
[[12, 13, 14],
[15, 16, 17]]])
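(The repeat isn't strictly necessary, by the way: broadcasting a (4, 1, 3) mask against the (4, 2, 3) tensor does the same thing, so a * d.unsqueeze(1) gives the same result.)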
Done.
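For reference, here's the whole recipe as one self-contained snippet (a minimal sketch; the function name zero_beyond_lengths and its signature are mine, not from anything above):

import torch

def zero_beyond_lengths(x, lengths):
    # x: (batch, seq, dim) tensor
    # lengths: (batch,) tensor saying how many entries of the final dim to keep
    idx = torch.arange(x.size(-1), device=x.device)   # (dim,)
    mask = idx < lengths.unsqueeze(-1)                 # (batch, dim) bool mask
    return x * mask.unsqueeze(1)                       # broadcast over the seq dim

a = torch.tensor([[[1, 2, 3], [4, 5, 6]],
                  [[10, 11, 12], [13, 14, 15]],
                  [[6, 7, 8], [9, 10, 11]],
                  [[12, 13, 14], [15, 16, 17]]])
print(zero_beyond_lengths(a, torch.tensor([0, 1, 2, 3])))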