Skip to content

Instantly share code, notes, and snippets.

@Sayam753
Created July 28, 2022 07:31
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Sayam753/c9f93dc3212db645dfb5b1065202ed32 to your computer and use it in GitHub Desktop.
Save Sayam753/c9f93dc3212db645dfb5b1065202ed32 to your computer and use it in GitHub Desktop.
import numpy as np
import torch
def test_b_bar(A, b):
    """Check a closed-form gradient of ``sum(solve(A, b))`` w.r.t. ``b``.

    The solution ``c = torch.linalg.solve(A, b)`` is reduced with ``sum`` and
    back-propagated.  The reference gradient is ``solve(A^T, c_bar)``, where
    ``c_bar`` is the gradient that reached ``c``; it is then summed down to
    the shape of ``b`` so that axes along which ``b`` was broadcast are
    reduced.

    Parameters
    ----------
    A : array-like accepted by ``torch.tensor`` (floating point)
        Coefficient matrix or batch of matrices.
    b : array-like accepted by ``torch.tensor`` (floating point)
        Right-hand side.

    Returns
    -------
    bool
        ``True`` when the reference gradient matches ``b``'s autograd
        gradient within tolerance.

    Raises
    ------
    RuntimeError
        Propagated from ``torch.linalg.solve`` when it rejects the inputs.
    """
    A_tensor = torch.tensor(A, requires_grad=True)
    b_tensor = torch.tensor(b, requires_grad=True)

    c_tensor = torch.linalg.solve(A_tensor, b_tensor)
    # c is not a leaf tensor; keep its gradient so it can be read back below.
    c_tensor.retain_grad()
    c_tensor.sum().backward()

    # The reference computation must not extend the autograd graph.
    with torch.no_grad():
        A_transposed = torch.swapaxes(A_tensor, -1, -2)
        b_bar = torch.linalg.solve(A_transposed, c_tensor.grad)
        # Undo broadcasting: collapse every axis that b lacked or had size 1.
        b_bar = b_bar.sum_to_size(b_tensor.shape)

    # Two different float code paths: compare with a tolerance, not bit-for-bit.
    return torch.allclose(b_bar, b_tensor.grad, rtol=1e-4, atol=1e-5)


# Despite its name this is a helper that takes explicit arguments; keep pytest
# from collecting it and failing on missing "A"/"b" fixtures.
test_b_bar.__test__ = False
def test_A_bar(A, b):
    """Check a closed-form gradient of ``sum(solve(A, b))`` w.r.t. ``A``.

    With ``c = torch.linalg.solve(A, b)`` and ``c_bar`` the gradient that
    reached ``c``, the reference gradient is ``-b_bar @ c^T`` where
    ``b_bar = solve(A^T, c_bar)`` is computed per batch element, i.e. before
    any reduction over broadcast axes.  The result is then summed down to the
    shape of ``A``.

    Parameters
    ----------
    A : array-like accepted by ``torch.tensor`` (floating point)
        Coefficient matrix or batch of matrices.
    b : array-like accepted by ``torch.tensor`` (floating point)
        Right-hand side.

    Returns
    -------
    bool
        ``True`` when the reference gradient matches ``A``'s autograd
        gradient within tolerance.

    Raises
    ------
    RuntimeError
        Propagated from ``torch.linalg.solve`` when it rejects the inputs.
    """
    A_tensor = torch.tensor(A, requires_grad=True)
    b_tensor = torch.tensor(b, requires_grad=True)

    c_tensor = torch.linalg.solve(A_tensor, b_tensor)
    # c is not a leaf tensor; keep its gradient so it can be read back below.
    c_tensor.retain_grad()
    c_tensor.sum().backward()

    # The reference computation must not extend the autograd graph.
    with torch.no_grad():
        c = c_tensor.detach()
        # Use the un-reduced b_bar.  b_tensor.grad has already been summed
        # over the axes along which b was broadcast, which over-counts once
        # it is multiplied by the per-batch solution c.
        b_bar = torch.linalg.solve(torch.swapaxes(A_tensor, -1, -2), c_tensor.grad)

        # A solution with one axis fewer than A means b was treated as a
        # (batch of) vector(s); promote to column matrices so one formula
        # covers both the vector and the matrix right-hand side.
        if c.ndim == A_tensor.ndim - 1:
            b_bar = b_bar.unsqueeze(-1)
            c = c.unsqueeze(-1)

        A_bar = -(b_bar @ torch.swapaxes(c, -1, -2))
        # Undo broadcasting: collapse every axis that A lacked or had size 1.
        A_bar = A_bar.sum_to_size(A_tensor.shape)

    # Two different float code paths: compare with a tolerance, not bit-for-bit.
    return torch.allclose(A_bar, A_tensor.grad, rtol=1e-4, atol=1e-5)


# Despite its name this is a helper that takes explicit arguments; keep pytest
# from collecting it and failing on missing "A"/"b" fixtures.
test_A_bar.__test__ = False
# (A_shape, b_shape) pairs exercised by the driver loop below.
total_shapes = [
    # cases where b.ndim == A.ndim - 1
    ((5, 5), (5,)),
    ((1, 5, 5), (1, 5)),
    ((10, 5, 5), (1, 5)),
    ((10, 5, 5), (5, 5)),
    ((1, 5, 5), (10, 5)),
    # cases where b.ndim == A.ndim
    ((1, 5, 5), (10, 5, 3)),
    ((10, 5, 5), (10, 5, 3)),
    ((1, 5, 5), (1, 5, 3)),
    # b has size-1 batch axes that broadcast against A
    # (flagged as failing when this script was written)
    ((10, 5, 5), (1, 5, 3)),
    ((10, 10, 5, 5), (1, 1, 5, 3)),
    ((7, 5, 5), (1, 5, 3)),
]

# Run the checks only when executed as a script, not on import.  The loop
# stays at module level so its variables remain module globals.
if __name__ == "__main__":
    for A_shape, b_shape in total_shapes:
        last_dim = A_shape[-1]
        values = np.random.randint(0, 10, size=(last_dim, last_dim)).astype(np.float32)
        # Gram matrix plus identity: symmetric positive definite, so the
        # system is never singular and stays reasonably conditioned in float32.
        A = values @ values.T + np.eye(last_dim, dtype=np.float32)
        # Repeat the same matrix across every batch position of A_shape.
        A = np.broadcast_to(A, A_shape).copy()
        b = np.random.randint(0, 10, size=b_shape).astype(np.float32)
        try:
            flag = test_A_bar(A, b)
        except Exception as exc:
            # Keep Ctrl-C / SystemExit working and show what actually went wrong.
            print(A_shape, b_shape, "torch does not support", repr(exc))
        else:
            if not flag:
                print(A_shape, b_shape, "is failing")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment