@tomgrek
Created August 27, 2020 17:27
Mask a sequence in PyTorch

I do this so many times, might as well make a gist.

Start with an example 4 (batch size) x 2 (sequence length) x 3 (embedding dim) tensor.

import torch
a = torch.tensor([[[1,2,3],[4,5,6]],[[10,11,12],[13,14,15]],[[6,7,8],[9,10,11]],[[12,13,14],[15,16,17]]])

tensor([[[ 1,  2,  3],
         [ 4,  5,  6]],

        [[10, 11, 12],
         [13, 14, 15]],

        [[ 6,  7,  8],
         [ 9, 10, 11]],

        [[12, 13, 14],
         [15, 16, 17]]])

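Say we want to keep only the first 0, 1, 2, and 3 entries of the final dim for the four batch items respectively, and zero out everything else. Put those cutoffs in a column, one per batch item: c = torch.tensor([0,1,2,3]).unsqueeze(-1)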

tensor([[0],
        [1],
        [2],
        [3]])

Next we have to create a sequence of positions along the final dim: b = torch.arange(a.size(-1))

tensor([0, 1, 2])
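Compare the positions against the cutoffs. Broadcasting the (3,) row against the (4, 1) column gives a 4 x 3 mask: d = b < c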

tensor([[False, False, False],
        [ True, False, False],
        [ True,  True, False],
        [ True,  True,  True]])
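
The comparison relies on broadcasting: a column of cutoffs with shape (4, 1) compared against a row of positions with shape (3,) yields one mask row per batch item. A minimal sketch that checks the shapes, assuming the positions range over the final dim (size 3). The names `cutoffs`, `positions`, and `mask` are illustrative only:

```python
import torch

cutoffs = torch.tensor([0, 1, 2, 3]).unsqueeze(-1)  # shape (4, 1): one cutoff per batch item
positions = torch.arange(3)                         # shape (3,): positions in the final dim

# Broadcasting expands both sides to (4, 3) before comparing element-wise.
mask = positions < cutoffs

print(cutoffs.shape, positions.shape, mask.shape)
# Row i has exactly cutoffs[i] leading True values.
print(mask.sum(dim=-1))
```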
Add a sequence-length dim and repeat the mask along it so it matches the shape of a: e = d.unsqueeze(1).repeat(1,2,1)

tensor([[[False, False, False],
         [False, False, False]],

        [[ True, False, False],
         [ True, False, False]],

        [[ True,  True, False],
         [ True,  True, False]],

        [[ True,  True,  True],
         [ True,  True,  True]]])
Multiply to zero out the masked entries: a * e

tensor([[[ 0,  0,  0],
         [ 0,  0,  0]],

        [[10,  0,  0],
         [13,  0,  0]],

        [[ 6,  7,  0],
         [ 9, 10,  0]],

        [[12, 13, 14],
         [15, 16, 17]]])
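
The repeat makes an explicit copy of the mask for each sequence position. Broadcasting can likely do the same job without the copy, since a size-1 dim is stretched automatically during the multiply. A small sketch comparing the two, rebuilding the tensors so it runs on its own (the names `repeated` and `broadcast` are illustrative):

```python
import torch

a = torch.tensor([[[1, 2, 3], [4, 5, 6]],
                  [[10, 11, 12], [13, 14, 15]],
                  [[6, 7, 8], [9, 10, 11]],
                  [[12, 13, 14], [15, 16, 17]]])
c = torch.tensor([0, 1, 2, 3]).unsqueeze(-1)
d = torch.arange(a.size(-1)) < c                   # (4, 3) boolean mask

repeated = d.unsqueeze(1).repeat(1, a.size(1), 1)  # (4, 2, 3), copies the mask
broadcast = d.unsqueeze(1)                         # (4, 1, 3), no copy

# Both masks give the same product with a.
same = torch.equal(a * repeated, a * broadcast)
print(same)
```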

Done.
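
Since this comes up so often, the steps can be wrapped into one function. The following is a sketch, not a library API: `mask_last_dim` and its `lengths` and `fill_value` arguments are made-up names, and it swaps the multiply for `masked_fill`, which also covers fill values other than zero.

```python
import torch


def mask_last_dim(x, lengths, fill_value=0):
    """Keep the first lengths[i] entries of the final dim for batch item i.

    Hypothetical helper: x has shape (batch, ..., dim), lengths has shape (batch,).
    """
    positions = torch.arange(x.size(-1), device=x.device)   # (dim,)
    keep = positions < lengths.to(x.device).unsqueeze(-1)   # (batch, dim)
    # Insert singleton dims so keep broadcasts over every middle dim of x.
    keep = keep.view(x.size(0), *([1] * (x.dim() - 2)), x.size(-1))
    return x.masked_fill(~keep, fill_value)


a = torch.tensor([[[1, 2, 3], [4, 5, 6]],
                  [[10, 11, 12], [13, 14, 15]],
                  [[6, 7, 8], [9, 10, 11]],
                  [[12, 13, 14], [15, 16, 17]]])
masked = mask_last_dim(a, torch.tensor([0, 1, 2, 3]))
print(masked)
```

Calling it on the example tensor with cutoffs 0, 1, 2, 3 should reproduce the result of a * e.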
