
@zou3519
Last active July 31, 2022 10:43
import time
import argparse
from functorch import vmap, jacrev, jacfwd
import torch
import torch.nn as nn
import torch.nn.functional as F
torch.backends.cuda.matmul.allow_tf32 = False
_ = torch.manual_seed(0)
device = "cuda" if torch.cuda.is_available() else "cpu"
D1 = 2 # x, y
D2 = 3 # u, v, p
B = 10000
x = torch.randn(B, D1).to(device)
run_backward = False
model = nn.Sequential(
    nn.Linear(D1, 512, bias=False),
    nn.ReLU(),
    nn.Linear(512, D2, bias=False),
).to(device)
weights = (model[0].weight, model[2].weight)
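
# old_linear below routes the per-sample matmul through an unsqueeze/mm/squeeze_
# round-trip instead of calling F.linear directly; the benchmark compares the
# graphs and timings of these two paths.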
def old_linear(x, weight):
    return x.unsqueeze(0).mm(weight.t()).squeeze_(0)

# Functional version of the network; shadows the nn.Sequential `model` above
# and takes the weights explicitly so functorch can differentiate through it.
def model(weights, x, use_old_linear=False):
    weight1, weight2 = weights
    x = old_linear(x, weight1) if use_old_linear else F.linear(x, weight1)
    x = x.relu()
    x = old_linear(x, weight2) if use_old_linear else F.linear(x, weight2)
    return x

def predict(x, nvtx=True, use_old_linear=False):
    if nvtx:
        torch.cuda.nvtx.range_push("forward")
    out = model(weights, x, use_old_linear)
    if nvtx:
        torch.cuda.nvtx.range_pop()
    return out, out  # returning two outputs is needed for the jacrev auxiliary object
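
# vmap maps over the batch dimension of x (in_dims=(0, None, None)); jacrev
# differentiates predict w.r.t. x, so each element of the returned pair holds
# per-sample Jacobians of shape [B, D2, D1] = [10000, 3, 2].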
def f(x, nvtx, use_old_linear):
    return vmap(jacrev(predict), (0, None, None))(x, nvtx, use_old_linear)

def quantity_with_old_linear():
    return f(x, True, True)

def quantity_with_new_linear():
    return f(x, True, False)

def benchmark(func):
    N = 20
    # Synchronize before starting the clock so pending GPU work is excluded
    # from the measurement, then synchronize again before stopping it.
    torch.cuda.synchronize()
    start = time.time()
    for i in range(N):
        torch.cuda.nvtx.range_push(func.__name__)
        _ = func()
        torch.cuda.nvtx.range_pop()
    torch.cuda.synchronize()
    time_ms = ((time.time() - start) / N) * 1000
    print(f"{func.__name__}: {time_ms:.3f} ms")
from functorch import make_fx
print("quantity using old linear")
gm = make_fx(f)(x, False, True)
print(gm.code)
print("quantity using new linear")
gm = make_fx(f)(x, False, False)
print(gm.code)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-b", "--backward", default=False, action="store_true")
    args = parser.parse_args()
    if args.backward:
        run_backward = True
        print("===== benchmark with backward =====")
    else:
        print("===== benchmark without backward =====")
    # warm up
    for i in range(10):
        quantity_with_old_linear()
        quantity_with_new_linear()
    # benchmark the batched jacobian computation
    benchmark(quantity_with_new_linear)
    benchmark(quantity_with_old_linear)
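
# Sample output from a CUDA run: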
"""
quantity using old linear
def forward(self, x_1, nvtx_1, use_old_linear_1):
    unsqueeze = torch.ops.aten.unsqueeze(x_1, 1); x_1 = None
    _reshape_alias = torch.ops.aten._reshape_alias(unsqueeze, [10000, 2], [2, 1]); unsqueeze = None
    _tensor_constant0 = self._tensor_constant0
    mm = torch.ops.aten.mm(_reshape_alias, _tensor_constant0); _reshape_alias = _tensor_constant0 = None
    _unsafe_view = torch.ops.aten._unsafe_view(mm, [10000, 1, 512]); mm = None
    squeeze_ = torch.ops.aten.squeeze_(_unsafe_view, 1); _unsafe_view = None
    relu = torch.ops.aten.relu(squeeze_); squeeze_ = None
    detach = torch.ops.aten.detach(relu)
    unsqueeze_1 = torch.ops.aten.unsqueeze(relu, 1)
    _reshape_alias_1 = torch.ops.aten._reshape_alias(unsqueeze_1, [10000, 512], [512, 1]); unsqueeze_1 = None
    _tensor_constant1 = self._tensor_constant1
    mm_1 = torch.ops.aten.mm(_reshape_alias_1, _tensor_constant1); _reshape_alias_1 = _tensor_constant1 = None
    _unsafe_view_1 = torch.ops.aten._unsafe_view(mm_1, [10000, 1, 3]); mm_1 = None
    squeeze__1 = torch.ops.aten.squeeze_(_unsafe_view_1, 1); _unsafe_view_1 = None
    new_empty = torch.ops.aten.new_empty(squeeze__1, [10000, 6, 3], dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
    zero_ = torch.ops.aten.zero_(new_empty); new_empty = None
    new_empty_1 = torch.ops.aten.new_empty(squeeze__1, [10000, 6, 3], dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False); squeeze__1 = None
    zero__1 = torch.ops.aten.zero_(new_empty_1); new_empty_1 = None
    diagonal = torch.ops.aten.diagonal(zero_, 0, 1, 2)
    fill_ = torch.ops.aten.fill_(diagonal, 1); diagonal = None
    diagonal_1 = torch.ops.aten.diagonal(zero__1, -3, 1, 2)
    fill__1 = torch.ops.aten.fill_(diagonal_1, 1); diagonal_1 = None
    view = torch.ops.aten.view(zero_, [10000, 6, 3]); zero_ = None
    view_1 = torch.ops.aten.view(zero__1, [10000, 6, 3]); zero__1 = None
    add = torch.ops.aten.add(view, view_1); view = view_1 = None
    unsqueeze_2 = torch.ops.aten.unsqueeze(add, 2); add = None
    _reshape_alias_2 = torch.ops.aten._reshape_alias(unsqueeze_2, [10000, 6, 3], [18, 3, 1]); unsqueeze_2 = None
    _reshape_alias_3 = torch.ops.aten._reshape_alias(_reshape_alias_2, [60000, 3], [3, 1]); _reshape_alias_2 = None
    _tensor_constant2 = self._tensor_constant2
    mm_2 = torch.ops.aten.mm(_reshape_alias_3, _tensor_constant2); _reshape_alias_3 = _tensor_constant2 = None
    _unsafe_view_2 = torch.ops.aten._unsafe_view(mm_2, [10000, 6, 512]); mm_2 = None
    _unsafe_view_3 = torch.ops.aten._unsafe_view(_unsafe_view_2, [10000, 6, 1, 512]); _unsafe_view_2 = None
    squeeze = torch.ops.aten.squeeze(_unsafe_view_3, 2); _unsafe_view_3 = None
    view_2 = torch.ops.aten.view(relu, [10000, 1, 512]); relu = None
    threshold_backward = torch.ops.aten.threshold_backward(squeeze, view_2, 0); squeeze = view_2 = None
    unsqueeze_3 = torch.ops.aten.unsqueeze(threshold_backward, 2); threshold_backward = None
    _reshape_alias_4 = torch.ops.aten._reshape_alias(unsqueeze_3, [10000, 6, 512], [3072, 512, 1]); unsqueeze_3 = None
    _reshape_alias_5 = torch.ops.aten._reshape_alias(_reshape_alias_4, [60000, 512], [512, 1]); _reshape_alias_4 = None
    _tensor_constant3 = self._tensor_constant3
    mm_3 = torch.ops.aten.mm(_reshape_alias_5, _tensor_constant3); _reshape_alias_5 = _tensor_constant3 = None
    _unsafe_view_4 = torch.ops.aten._unsafe_view(mm_3, [10000, 6, 2]); mm_3 = None
    _unsafe_view_5 = torch.ops.aten._unsafe_view(_unsafe_view_4, [10000, 6, 1, 2]); _unsafe_view_4 = None
    squeeze_1 = torch.ops.aten.squeeze(_unsafe_view_5, 2); _unsafe_view_5 = None
    split_with_sizes = torch.ops.aten.split_with_sizes(squeeze_1, [3, 3], 1); squeeze_1 = None
    getitem = split_with_sizes[0]
    getitem_1 = split_with_sizes[1]; split_with_sizes = None
    view_3 = torch.ops.aten.view(getitem, [10000, 3, 2]); getitem = None
    view_4 = torch.ops.aten.view(getitem_1, [10000, 3, 2]); getitem_1 = None
    return (view_3, view_4)
quantity using new linear
def forward(self, x_1, nvtx_1, use_old_linear_1):
    permute = torch.ops.aten.permute(x_1, [1, 0]); x_1 = None
    _tensor_constant0 = self._tensor_constant0
    mm = torch.ops.aten.mm(_tensor_constant0, permute); _tensor_constant0 = permute = None
    relu = torch.ops.aten.relu(mm); mm = None
    detach = torch.ops.aten.detach(relu)
    permute_1 = torch.ops.aten.permute(relu, [0, 1])
    _tensor_constant1 = self._tensor_constant1
    mm_1 = torch.ops.aten.mm(_tensor_constant1, permute_1); _tensor_constant1 = permute_1 = None
    new_empty = torch.ops.aten.new_empty(mm_1, [10000, 6, 3], dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
    zero_ = torch.ops.aten.zero_(new_empty); new_empty = None
    new_empty_1 = torch.ops.aten.new_empty(mm_1, [10000, 6, 3], dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False); mm_1 = None
    zero__1 = torch.ops.aten.zero_(new_empty_1); new_empty_1 = None
    diagonal = torch.ops.aten.diagonal(zero_, 0, 1, 2)
    fill_ = torch.ops.aten.fill_(diagonal, 1); diagonal = None
    diagonal_1 = torch.ops.aten.diagonal(zero__1, -3, 1, 2)
    fill__1 = torch.ops.aten.fill_(diagonal_1, 1); diagonal_1 = None
    view = torch.ops.aten.view(zero_, [10000, 6, 3]); zero_ = None
    view_1 = torch.ops.aten.view(zero__1, [10000, 6, 3]); zero__1 = None
    add = torch.ops.aten.add(view, view_1); view = view_1 = None
    permute_2 = torch.ops.aten.permute(add, [0, 2, 1]); add = None
    expand = torch.ops.aten.expand(permute_2, [10000, 3, 6]); permute_2 = None
    _reshape_alias = torch.ops.aten._reshape_alias(expand, [10000, 3, 6], [18, 1, 3]); expand = None
    _tensor_constant2 = self._tensor_constant2
    bmm = torch.ops.aten.bmm(_tensor_constant2, _reshape_alias); _tensor_constant2 = _reshape_alias = None
    _unsafe_view = torch.ops.aten._unsafe_view(bmm, [10000, 512, 6]); bmm = None
    permute_3 = torch.ops.aten.permute(_unsafe_view, [0, 2, 1]); _unsafe_view = None
    permute_4 = torch.ops.aten.permute(relu, [1, 0]); relu = None
    view_2 = torch.ops.aten.view(permute_4, [10000, 1, 512]); permute_4 = None
    threshold_backward = torch.ops.aten.threshold_backward(permute_3, view_2, 0); permute_3 = view_2 = None
    permute_5 = torch.ops.aten.permute(threshold_backward, [0, 2, 1]); threshold_backward = None
    transpose = torch.ops.aten.transpose(permute_5, -2, -1); permute_5 = None
    _reshape_alias_1 = torch.ops.aten._reshape_alias(transpose, [60000, 512], [1, 60000]); transpose = None
    _tensor_constant3 = self._tensor_constant3
    mm_2 = torch.ops.aten.mm(_reshape_alias_1, _tensor_constant3); _reshape_alias_1 = _tensor_constant3 = None
    _unsafe_view_1 = torch.ops.aten._unsafe_view(mm_2, [10000, 6, 2]); mm_2 = None
    transpose_1 = torch.ops.aten.transpose(_unsafe_view_1, -2, -1); _unsafe_view_1 = None
    clone = torch.ops.aten.clone(transpose_1, memory_format = torch.contiguous_format); transpose_1 = None
    permute_6 = torch.ops.aten.permute(clone, [0, 2, 1]); clone = None
    split_with_sizes = torch.ops.aten.split_with_sizes(permute_6, [3, 3], 1); permute_6 = None
    getitem = split_with_sizes[0]
    getitem_1 = split_with_sizes[1]; split_with_sizes = None
    view_3 = torch.ops.aten.view(getitem, [10000, 3, 2]); getitem = None
    view_4 = torch.ops.aten.view(getitem_1, [10000, 3, 2]); getitem_1 = None
    return (view_3, view_4)
===== benchmark without backward =====
quantity_with_new_linear: 7.190 ms
quantity_with_old_linear: 1.366 ms
"""
@IvanYashchuk

On the latest master, the make_fx call from this script raises an error:

File ~/dev/pytorch/master/torch/fx/experimental/proxy_tensor.py:134, in proxy_call(func_overload, args, kwargs)
    132     if t.constant is not None:
    133         with maybe_disable_fake_tensor_mode():
--> 134             return t.constant.item()
    135     raise RuntimeError("It appears that you're trying to get value out of a tracing tensor - erroring out! "
    136                        "It's likely that this is caused by data-dependent control flow or similar."
    137                        "Try torch.fx.experimental.proxy_tensor.enable_strict(False) to disable this check")
    139 def unwrap_proxy(e):

AttributeError: 'list' object has no attribute 'item'

Using tracing_mode="symbolic" raises another error:

File ~/dev/pytorch/master/functorch/functorch/_src/vmap.py:484, in _flat_vmap(func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs)
    483 def _flat_vmap(func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs):
--> 484     vmap_level = _vmap_increment_nesting(batch_size, randomness)
    485     try:
    486         batched_inputs = _create_batched_inputs(flat_in_dims, flat_args, vmap_level, args_spec)

TypeError: _vmap_increment_nesting(): incompatible function arguments. The following argument types are supported:
    1. (arg0: int, arg1: str) -> int

Invoked with: <torch.SymIntNode object at 0x7f6275eec3f0>, 'error'

and tracing_mode="fake" segfaults.
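
For reference, a sketch of the three invocations (using the f and x defined in the script above; tracing_mode is make_fx's keyword argument):

    from functorch import make_fx

    make_fx(f)(x, False, True)                           # AttributeError: 'list' object has no attribute 'item'
    make_fx(f, tracing_mode="symbolic")(x, False, True)  # TypeError in _vmap_increment_nesting
    make_fx(f, tracing_mode="fake")(x, False, True)      # segfault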
