Created June 5, 2024 15:00
Gist: ArthurDelannoyazerty/61fd57fc798b1334e2cad6fd1f088c90
A small snippet that slices a tensor while keeping the operation differentiable, so backpropagation still reaches the mask tensor `out`.
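For contrast, here is a minimal sketch of the naive approach. It is an illustration added here, not part of the original gist, and the names `naive` and `kept` are hypothetical. Indexing with a boolean mask built from `out` converts the mask to `bool`, which drops it from the autograd graph. Multiplying by the mask first, as the snippet below does, keeps `out` in the graph.

```python
import torch

torch.manual_seed(0)
N = 4
x = torch.randn(N, 3)
out = torch.tensor([1., 0., 1., 0.], requires_grad=True)

# Naive slicing: converting the mask to bool removes it from the autograd graph
naive = x[out.bool()]
print(naive.requires_grad)  # False: no path back to `out`

# Masking by multiplication keeps `out` in the graph before slicing
res1 = x * out[:, None]
kept = res1[res1.nonzero()[:, 0].unique()]
print(kept.requires_grad)  # True: gradients can reach `out`
```

Both versions select the same two rows. Only the second records how those rows depend on `out`.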
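import torch

N = 4
x = torch.randn(N, 3)
# Mask with gradients enabled: 1 keeps a row, 0 zeroes it out
out = torch.tensor([1., 0., 1., 0.], requires_grad=True)
# Multiply instead of indexing with the mask, so `out` stays in the autograd graph
res1 = x * out[:, None]
# Row indices that still contain at least one nonzero value
idx = res1.nonzero()[:, 0].unique()
# Integer-index slicing: gradients flow back to res1 (and thus to `out`) for the kept rows
res2 = res1[idx]
# perform your operation here
res = res2.mean()
res.backward()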
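One way to check that gradients really reach `out` is sketched below. This is a minimal sketch, and the fixed seed and the `expected` tensor are additions for illustration, not part of the original gist. Because `res` is the sum of `x[i, j] * out[i]` over the kept rows divided by `res2.numel()`, the gradient for a kept row is `x[i].sum() / res2.numel()`. For a dropped row it is exactly 0.

```python
import torch

torch.manual_seed(0)
N = 4
x = torch.randn(N, 3)
out = torch.tensor([1., 0., 1., 0.], requires_grad=True)

res1 = x * out[:, None]
idx = res1.nonzero()[:, 0].unique()
res2 = res1[idx]
res = res2.mean()
res.backward()

# Analytical gradient of the mean with respect to each mask entry:
# kept rows get x[i].sum() / res2.numel(), dropped rows get 0
expected = torch.zeros(N)
expected[idx] = x[idx].sum(dim=1) / res2.numel()
print(out.grad)
print(torch.allclose(out.grad, expected))
```

The rows zeroed by the mask receive a zero gradient rather than `None`. This is because `out` is still a single leaf tensor in the graph, even though those rows were sliced away.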