Skip to content

Instantly share code, notes, and snippets.

@rkube
Last active July 9, 2021 20:31
Show Gist options
  • Save rkube/eaca642c1af516fdd04ebf69428aaaa4 to your computer and use it in GitHub Desktop.
Save rkube/eaca642c1af516fdd04ebf69428aaaa4 to your computer and use it in GitHub Desktop.
Verification of the QR-decomposition pullback (reverse-mode gradient) with PyTorch
import torch
def f1(A):
    """Return the sum of all entries of the Q factor of ``A``'s QR decomposition.

    The result is a scalar tensor, so calling ``.backward()`` on it yields the
    gradient of ``sum(Q)`` with respect to ``A``.

    Args:
        A: Matrix (tensor) to decompose.

    Returns:
        Scalar tensor holding ``Q.sum()``.
    """
    # torch.qr is deprecated in favour of torch.linalg.qr; the latter's default
    # mode ("reduced") corresponds to torch.qr's default some=True.
    q, _ = torch.linalg.qr(A)
    return q.sum()
def f2(A):
    """Return the sum of all entries of the R factor of ``A``'s QR decomposition.

    The result is a scalar tensor, so calling ``.backward()`` on it yields the
    gradient of ``sum(R)`` with respect to ``A``.

    Args:
        A: Matrix (tensor) to decompose.

    Returns:
        Scalar tensor holding ``R.sum()``.
    """
    # torch.qr is deprecated in favour of torch.linalg.qr; the latter's default
    # mode ("reduced") corresponds to torch.qr's default some=True.
    _, r = torch.linalg.qr(A)
    return r.sum()
# Fixed test inputs for the gradient checks below.

# Square case: 4 x 4.
V1 = torch.tensor(
    [
        [0.0107703, 0.2082, 0.257278, 0.509395],
        [0.886279, 0.29217, 0.579832, 0.470256],
        [0.693459, 0.594034, 0.0419015, 0.0499426],
        [0.225352, 0.481055, 0.0644581, 0.961967],
    ],
    requires_grad=True,
)

# Tall case: 7 x 4 (more rows than columns).
V2 = torch.tensor(
    [
        [0.502334, 0.110096, 1.1055, -0.602707],
        [-0.516984, -0.251176, -1.10673, -1.27967],
        [-0.560501, 0.369714, -3.21136, 0.997317],
        [-0.0192918, 0.0721164, -0.0740145, 0.302423],
        [0.128064, -1.50343, 0.150976, -0.036446],
        [1.85278, 1.56417, 0.769278, 0.141974],
        [-0.827763, -1.39674, -0.310153, 0.521273],
    ],
    requires_grad=True,
)

# Wide case: 5 x 11 (more columns than rows).
V3 = torch.tensor(
    [
        [0.896748, -1.09122, 0.165837, -0.541716, 2.41747, -0.987177,
         0.447358, 2.06353, -1.58492, 0.259671, -1.15533],
        [-0.51353, -0.580517, -0.408438, -0.686494, -0.307974, 1.48217,
         -0.396211, -1.41453, -1.975, -1.04861, 0.81871],
        [-0.764799, -0.315437, -1.00978, -0.712932, 1.2453, -0.522772,
         0.366773, 0.134475, 0.076418, 0.199609, 0.15131],
        [-1.54143, -1.36145, -0.543805, -0.327059, -0.0499502, -1.5807,
         0.621673, -0.750421, 0.799335, -0.80678, 1.24723],
        [-0.0801625, -0.114457, -1.22672, 0.514836, -1.05478, 0.131842,
         0.182588, 0.170778, -0.734161, 0.197804, -0.981132],
    ],
    requires_grad=True,
)
def test_f(fun, V):
    """Print the gradient of ``fun`` evaluated at a fresh copy of ``V``.

    Args:
        fun: Callable mapping a tensor to a scalar tensor (``.backward()`` is
            called on its result without arguments).
        V: Input tensor; it is copied, so its own ``.grad`` is not modified.
    """
    # Differentiate with respect to a detached leaf copy so repeated calls
    # never accumulate gradients on the caller's tensor.
    leaf = V.detach().clone()
    leaf.requires_grad_(True)
    fun(leaf).backward()
    print(leaf.grad)
# Print d(sum Q)/dA for every test matrix, then d(sum R)/dA, in that order.
for _fun in (f1, f2):
    for _matrix in (V1, V2, V3):
        test_f(_fun, _matrix)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment