Created July 28, 2022 07:31
Save Sayam753/c9f93dc3212db645dfb5b1065202ed32 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import torch | |
def test_b_bar(A, b): | |
A_tensor = torch.tensor(A, requires_grad=True) | |
b_tensor = torch.tensor(b, requires_grad=True) | |
c_tensor = torch.linalg.solve(A_tensor, b_tensor) | |
c_tensor.retain_grad() | |
c_tensor.sum().backward() | |
b_bar = torch.linalg.solve(torch.swapaxes(A_tensor, -1, -2), c_tensor.grad) | |
unidimensions_index = np.where(np.array(list(b_tensor.shape)) == 1)[0] | |
if len(unidimensions_index) > 0: | |
b_bar = torch.sum(b_bar, dim=tuple(unidimensions_index), keepdim=True) | |
if b.ndim == A.ndim - 1 and A.ndim > 2: | |
b_bar = torch.sum(b_bar, dim=0) | |
return (b_bar == b_tensor.grad).all().item() | |
def test_A_bar(A, b): | |
A_tensor = torch.tensor(A, requires_grad=True) | |
b_tensor = torch.tensor(b, requires_grad=True) | |
c_tensor = torch.linalg.solve(A_tensor, b_tensor) | |
c_tensor.retain_grad() | |
c_tensor.sum().backward() | |
b_bar = b_tensor.grad | |
c = c_tensor | |
if b.ndim == A.ndim - 1: | |
A_bar = -(b_bar[..., None, :] * c[..., None]) | |
A_bar = torch.swapaxes(A_bar, -1, -2) | |
else: | |
if A_shape == (7, 5, 5) and b_shape == (1, 5, 3): | |
breakpoint() | |
A_bar = -(b_bar @ torch.swapaxes(c, -1, -2)) | |
unidimensions_index = np.where(np.array(list(A_tensor.shape)) == 1)[0] | |
if len(unidimensions_index) > 0: | |
A_bar = torch.sum(A_bar, dim=tuple(unidimensions_index), keepdim=True) | |
return (A_bar == A_tensor.grad).all().item() | |
# (A shape, b shape) pairs fed to torch.linalg.solve(A, b) by the loop below.
# Some pairs may be rejected by torch; the loop reports those as unsupported.
# NOTE(review): torch presumably treats b as a batch of vectors only when b is
# 1-D or b.shape == A.shape[:-1]; other b.ndim == A.ndim - 1 pairs are then a
# matrix RHS or an error. Confirm against the installed torch version.
total_shapes = [
    # cases when b.ndim == A.ndim - 1
    ((5, 5), (5,)),
    ((1, 5, 5), (1, 5)),
    ((10, 5, 5), (1, 5)),
    ((10, 5, 5), (5, 5)),
    ((1, 5, 5), (10, 5)),
    # cases when b.ndim == A.ndim
    ((1, 5, 5), (10, 5, 3)),
    ((10, 5, 5), (10, 5, 3)),
    ((1, 5, 5), (1, 5, 3)),
    # cases when b.ndim == A.ndim and b's size-1 batch dims broadcast
    # against larger batch dims of A
    ((10, 5, 5), (1, 5, 3)),
    ((10, 10, 5, 5), (1, 1, 5, 3)),
    ((7, 5, 5), (1, 5, 3)),
]
# Seeded generator so any reported failure is reproducible run to run.
rng = np.random.default_rng(0)
for A_shape, b_shape in total_shapes:
    last_dim = A_shape[-1]
    values = rng.integers(0, 10, size=(last_dim, last_dim)).astype(np.float32)
    # values @ values.T alone is only positive semi-definite and can be
    # singular. Shifting the diagonal makes it positive definite, so a
    # singular-matrix error is never misreported as an unsupported shape.
    A = values @ values.T + last_dim * np.eye(last_dim, dtype=np.float32)
    A = np.full(A_shape, A)
    b = rng.integers(0, 10, size=b_shape).astype(np.float32)
    for grad_name, check in (("A_bar", test_A_bar), ("b_bar", test_b_bar)):
        try:
            flag = check(A, b)
        except RuntimeError as err:
            # Catch only torch's errors; any other exception is a bug in the
            # checks themselves and should surface, not be mislabelled.
            print(A_shape, b_shape, "torch does not support:", err)
            break  # both checks begin with the same solve(A, b) call
        if not flag:
            print(A_shape, b_shape, grad_name, "is failing")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.