Gist by @berlino, created October 20, 2023 19:46
import torch


if __name__ == "__main__":
    N, d = 128, 256  # d is unused in this script
    dtype = torch.float32

    # Reference: plain matrix-vector product o1 = A @ p, with gradients from its sum.
    A = torch.randn((N, N), dtype=dtype).cuda().requires_grad_(True)
    p = torch.empty((N,), dtype=dtype).uniform_(0.1, 0.9).cuda().requires_grad_(True)
    o1 = A @ p
    o1.sum().backward()
    A_grad = A.grad.clone()
    p_grad = p.grad.clone()
    A.grad.zero_()
    p.grad.zero_()

    # Recompute the same product in log space, tracking signs separately.
    # p is strictly positive, so only A needs a sign matrix.
    A = A.clone().detach().requires_grad_(True)
    p = p.clone().detach().requires_grad_(True)
    logA = torch.log(torch.abs(A))
    logp = torch.log(p)
    signA = torch.sign(A)
    # exp(logp_j + logA_ij) * sign(A_ij) = A_ij * p_j; summing over j gives (A @ p)_i.
    inter = (torch.exp(logp + logA) * signA).sum(dim=1)
    # Round-trip the result through log/sign representation and back.
    outsigns = torch.sign(inter)
    outprod = torch.log(torch.abs(inter))
    o2 = outsigns * outprod.exp()
    o2.sum().backward()

    # Forward values and gradients should agree up to float32 rounding error.
    print("fwd diff", torch.abs(o1 - o2).max())
    print("A grad diff", torch.abs(A_grad - A.grad).max())
    print("p grad diff", torch.abs(p_grad - p.grad).max())