A small snippet that slices a tensor while still letting backpropagation pass through the slice.
import torch

N = 4
x = torch.randn(N, 3)
out = torch.tensor([1., 0., 1., 0.], requires_grad=True)
res1 = x * out[:, None]
idx = res1.nonzero()[:, 0].unique()
res2 = res1[idx]
# perform your operation here
res = res2.mean()
res.backward()
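As an aside (not from the gist): the same effect can be sketched with a plain boolean row mask, since boolean indexing is also differentiable with respect to the indexed tensor. This assumes no selected row of `x` is exactly all zeros (in that corner case the `nonzero()` route would drop the row while the mask keeps it):

```python
import torch

torch.manual_seed(0)
x = torch.randn(4, 3)
out = torch.tensor([1., 0., 1., 0.], requires_grad=True)

res1 = x * out[:, None]          # zero out the non-selected rows
res2 = res1[out.detach() != 0]   # boolean row mask; gradients still flow through res1
res2.mean().backward()

print(out.grad)                  # non-zero gradient on the selected rows only
```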
ArthurDelannoyazerty commented Jun 5, 2024

Code

import torch
import torchviz

N = 4
x = torch.randn(N, 3)                                      # Tensor to slice
out = torch.tensor([1., 0., 1., 0.], requires_grad=True)   # Binary selection vector

out_t = out[:, None]                # Add a trailing dimension for broadcasting
res1 = x * out_t                    # Zero out every non-selected row
idx_nonz = res1.nonzero()           # Coordinates of the non-zero entries
idx_slic = idx_nonz[:, 0]           # Row index of every non-zero entry
idx_uniq = idx_slic.unique()        # Unique row indices to keep
res2 = res1[idx_uniq]               # Select those rows (indexing is differentiable w.r.t. res1, so backpropagation can pass through)

print('SHAPE ------------------------------------------------------------------------')
print('x        : ', x.shape)
print('out      : ', out.shape)
print('out_t    : ', out_t.shape)
print('res1     : ', res1.shape)
print('idx_nonz : ', idx_nonz.shape)
print('idx_slic : ', idx_slic.shape)
print('idx_uniq : ', idx_uniq.shape)
print('res2     : ', res2.shape)
print('\nVALUE ------------------------------------------------------------------------')
print('x        : ', x)
print('out      : ', out)
print('out_t    : ', out_t)
print('res1     : ', res1)
print('idx_nonz : ', idx_nonz)
print('idx_slic : ', idx_slic)
print('idx_uniq : ', idx_uniq)
print('res2     : ', res2)

# perform your operation here
res = res2.mean()

# backward
res.backward()

print('out grad : ', out.grad)
torchviz.make_dot(res)
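One way to sanity-check the gradient shown below (a sketch, not part of the gist): since `res1 = x * out[:, None]` and `res = res2.mean()`, the chain rule gives `d res / d out[i] = x[i, :].sum() / res2.numel()` for each selected row `i`, and 0 elsewhere:

```python
import torch

torch.manual_seed(0)
x = torch.randn(4, 3)
out = torch.tensor([1., 0., 1., 0.], requires_grad=True)

res1 = x * out[:, None]
idx = res1.nonzero()[:, 0].unique()
res2 = res1[idx]
res2.mean().backward()

# chain rule: d res / d out[i] = x[i, :].sum() / res2.numel() on selected rows, 0 otherwise
expected = torch.where(out.detach() != 0, x.sum(dim=1) / res2.numel(), torch.zeros_like(out.detach()))
print(torch.allclose(out.grad, expected))  # → True
```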

Output

SHAPE ------------------------------------------------------------------------
x        :  torch.Size([4, 3])
out      :  torch.Size([4])
out_t    :  torch.Size([4, 1])
res1     :  torch.Size([4, 3])
idx_nonz :  torch.Size([6, 2])
idx_slic :  torch.Size([6])
idx_uniq :  torch.Size([2])
res2     :  torch.Size([2, 3])

VALUE ------------------------------------------------------------------------
x        :  tensor([[-0.7161,  0.9767, -0.6543],
        [ 0.2182, -0.5371, -0.9532],
        [ 0.4981, -0.3703, -1.2246],
        [-0.7489,  1.0009,  0.7083]])
out      :  tensor([1., 0., 1., 0.], requires_grad=True)
out_t    :  tensor([[1.],
        [0.],
        [1.],
        [0.]], grad_fn=<UnsqueezeBackward0>)
res1     :  tensor([[-0.7161,  0.9767, -0.6543],
        [ 0.0000, -0.0000, -0.0000],
        [ 0.4981, -0.3703, -1.2246],
        [-0.0000,  0.0000,  0.0000]], grad_fn=<MulBackward0>)
idx_nonz :  tensor([[0, 0],
        [0, 1],
        [0, 2],
        [2, 0],
        [2, 1],
        [2, 2]])
idx_slic :  tensor([0, 0, 0, 2, 2, 2])
idx_uniq :  tensor([0, 2])
res2     :  tensor([[-0.7161,  0.9767, -0.6543],
        [ 0.4981, -0.3703, -1.2246]], grad_fn=<IndexBackward0>)
out grad :  tensor([-0.0656,  0.0000, -0.1828,  0.0000])

Image backpropagation

(torchviz computation graph of `res`, produced by `make_dot`)
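A likely explanation for why the backpropagation "can pass" (my reading, not stated in the gist): `nonzero()` returns plain integer indices that carry no gradient history, but indexing `res1` with them (the `IndexBackward0` node in the graph) is differentiable with respect to `res1` itself, so gradients flow back through the multiplication to `out`:

```python
import torch

x = torch.randn(4, 3)
out = torch.tensor([1., 0., 1., 0.], requires_grad=True)
res1 = x * out[:, None]

idx = res1.nonzero()[:, 0].unique()
print(idx.requires_grad)   # False: the indices are integer constants for autograd
print(res1[idx].grad_fn)   # indexing records a backward node w.r.t. res1
```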

Source
