import time
import argparse
from functorch import vmap, jacrev, jacfwd
import torch
import torch.nn as nn
torch.backends.cuda.matmul.allow_tf32 = False
_ = torch.manual_seed(0)
device = "cuda" if torch.cuda.is_available() else "cpu"
D1 = 2 # x, y
D2 = 3 # u, v, p
B = 10000
x = torch.randn(B, D1).to(device)
run_backward = False
model = nn.Sequential(
    nn.Linear(D1, 512),
    nn.ReLU(),
    nn.Linear(512, 512),
    nn.ReLU(),
    nn.Linear(512, 512),
    nn.ReLU(),
    nn.Linear(512, 512),
    nn.ReLU(),
    nn.Linear(512, 512),
    nn.ReLU(),
    nn.Linear(512, 512),
    nn.ReLU(),
    nn.Linear(512, D2),
).to(device)
def predict(x):
    torch.cuda.nvtx.range_push("forward")
    out = model(x)
    torch.cuda.nvtx.range_pop()
    # returning the output twice is needed for jacrev's has_aux: the first value
    # is differentiated, the second is passed through as the auxiliary output
    return out, out
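# Reference: build the Hessian with plain autograd. Each output column needs its
# own torch.autograd.grad call (D2 calls for the Jacobian rows, then D2 * D1 more
# for the Hessian rows), all with create_graph=True so the result stays differentiable.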
def reference_hessian():
    x_ = x.clone().requires_grad_()
    ones = torch.ones(B, device=x.device)
    pred, _ = predict(x_)
    jacobian_rows = [None] * D2
    hessian_rows = [None] * (D2 * D1)
    for i in range(D2):
        torch.cuda.nvtx.range_push("autograd jacobian")
        jacobian_rows[i] = torch.autograd.grad(
            pred[:, i], x_, ones, create_graph=True
        )[0]
        torch.cuda.nvtx.range_pop()
    for i in range(D2):
        for j in range(D1):
            torch.cuda.nvtx.range_push("autograd hessian")
            hessian_rows[i * D1 + j] = torch.autograd.grad(
                jacobian_rows[i][:, j], x_, ones, create_graph=True
            )[0]
            torch.cuda.nvtx.range_pop()
    jacobian = torch.stack(jacobian_rows)  # [D2, B, D1]
    hessian = torch.stack(hessian_rows)  # [D2 * D1, B, D1]
    if run_backward:
        l = hessian.sum()
        l.backward()
    return hessian.transpose(0, 1), pred
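# functorch: compose jacrev (reverse mode over the D2 outputs) with jacfwd
# (forward mode over the D1 inputs) to get a per-sample Hessian of shape
# [D2, D1, D1], then vmap over the batch dimension to compute all B Hessians
# in a single vectorized call. has_aux=True threads the prediction through
# both transforms so it can be returned alongside the Hessian.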
def functorch_hessian():
    x_ = x.clone().requires_grad_()
    hessian, pred = vmap(
        jacfwd(jacrev(predict, argnums=0, has_aux=True), argnums=0, has_aux=True),
        in_dims=0,
    )(x_)  # [B, D2, D1, D1]
    if run_backward:
        l = hessian.sum()
        l.backward()
    return hessian, pred
def validate_result():
    # compare the functorch result against the autograd reference
    ref_hes, ref_pred = reference_hessian()
    ft_hes, ft_pred = functorch_hessian()
    ref_hes = ref_hes.view_as(ft_hes)
    print(f"max pred error: functorch: {(ref_pred - ft_pred).abs().max():.2e}")
    print(f"max hessian error: functorch: {(ref_hes - ft_hes).abs().max():.2e}")
def benchmark(func):
    N = 20
    torch.cuda.synchronize()
    start = time.time()
    for i in range(N):
        torch.cuda.nvtx.range_push(func.__name__)
        _ = func()
        torch.cuda.nvtx.range_pop()
    torch.cuda.synchronize()
    time_ms = ((time.time() - start) / N) * 1000
    print(f"{func.__name__}: {time_ms:.3f} ms")
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-b", "--backward", default=False, action="store_true")
    args = parser.parse_args()
    if args.backward:
        run_backward = True
        print("===== benchmark with backward =====")
    else:
        print("===== benchmark without backward =====")
    validate_result()
    # warm up
    for i in range(10):
        reference_hessian()
        functorch_hessian()
    # benchmark hessian
    benchmark(reference_hessian)
    benchmark(functorch_hessian)
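# Usage (assuming the script above is saved as, e.g., benchmark_hessian.py):
#   python benchmark_hessian.py             # validate, then time both Hessian implementations
#   python benchmark_hessian.py --backward  # additionally time a backward pass through the Hessian
# ------------------------------------------------------------------
# The snippet below is a separate, standalone script that benchmarks a single
# torch.autograd.grad call against loss.backward() using torch.utils.benchmark.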
import torch
import torch.utils.benchmark as benchmark
import torch.nn as nn
# --------------------------------------------------------
# smoke test: run both autograd.grad and loss.backward once before benchmarking
_ = torch.manual_seed(0)
device = "cuda" if torch.cuda.is_available() else "cpu"
D1 = 2 # x, y
D2 = 3 # u, v, p
B = 10000
x = torch.randn(B, D1).to(device).requires_grad_()
model = nn.Sequential(
    nn.Linear(D1, 512),
    nn.ReLU(),
    nn.Linear(512, 512),
    nn.ReLU(),
    nn.Linear(512, 512),
    nn.ReLU(),
    nn.Linear(512, D2),
).to(device)
pred = model(x)
loss = pred.sum()
torch.autograd.grad(outputs=loss, inputs=x, retain_graph=True)
loss.backward(retain_graph=True)
# --------------------------------------------------------
# benchmark autograd.grad
t0 = benchmark.Timer(
    stmt="torch.autograd.grad(outputs=loss, inputs=x, retain_graph=True)",
    setup="""
import torch
import torch.nn as nn
device = "cuda" if torch.cuda.is_available() else "cpu"
D1 = 2  # x, y
D2 = 3  # u, v, p
B = 1000
x = torch.randn(B, D1).to(device).requires_grad_()
model = nn.Sequential(
    nn.Linear(D1, 512),
    nn.ReLU(),
    nn.Linear(512, 512),
    nn.ReLU(),
    nn.Linear(512, D2),
).to(device)
pred = model(x)
loss = pred.sum()
""",
    num_threads=1,
)
print(t0.blocked_autorange())
print(t0.collect_callgrind())
# benchmark loss.backward
t1 = benchmark.Timer(
    stmt="loss.backward(retain_graph=True)",
    setup="""
import torch
import torch.nn as nn
device = "cuda" if torch.cuda.is_available() else "cpu"
D1 = 2  # x, y
D2 = 3  # u, v, p
B = 1000
x = torch.randn(B, D1).to(device).requires_grad_()
model = nn.Sequential(
    nn.Linear(D1, 512),
    nn.ReLU(),
    nn.Linear(512, 512),
    nn.ReLU(),
    nn.Linear(512, D2),
).to(device)
pred = model(x)
loss = pred.sum()
""",
    num_threads=1,
)
print(t1.blocked_autorange())
print(t1.collect_callgrind())
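# Note: collect_callgrind() runs the statement under valgrind's Callgrind tool to
# gather deterministic instruction counts, so valgrind must be available on the
# system; blocked_autorange() alone is sufficient for wall-clock timings.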