@xmfan
Created April 2, 2024 00:57
test_basic, before boxing inputs
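What follows is the full log of a single compiled-autograd run. A hedged reconstruction of the kind of test that produces it: a tiny Linear -> ReLU -> Linear -> ReLU model whose backward pass runs under compiled autograd, so the autograd graph itself is captured and handed to torch.compile. The model and calling convention below are inferred from the tensor shapes in the dumps, not copied from the real test_basic.

import torch
from torch._dynamo import compiled_autograd

model = torch.nn.Sequential(
    torch.nn.Linear(4, 4), torch.nn.ReLU(),
    torch.nn.Linear(4, 4), torch.nn.ReLU(),
)
x = torch.randn(2, 4)

# compiled_autograd.enable takes a compiler_fn that is applied to the captured
# backward graph; passing torch.compile directly is an assumption of this sketch.
with compiled_autograd.enable(torch.compile):
    loss = model(x).sum()   # scalar loss -> the f32[] grad seed (inputs[0] below)
    loss.backward()         # this backward pass is what gets traced below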
/home/xmfan/.conda/envs/benchmarks/lib/python3.10/site-packages/transformers/utils/generic.py:441: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
_torch_pytree._register_pytree_node(
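The UserWarning above is unrelated third-party noise: the installed transformers still registers its model-output types through the private pytree hook. The deprecation message itself names the replacement; a minimal sketch with a placeholder container type (Pair and the lambdas are illustrative, not the transformers code):

import torch.utils._pytree as pytree

class Pair:
    def __init__(self, a, b):
        self.a, self.b = a, b

# deprecated private spelling (what transformers/utils/generic.py still calls):
#   pytree._register_pytree_node(Pair, flatten_fn, unflatten_fn)
# public spelling suggested by the warning:
pytree.register_pytree_node(
    Pair,
    lambda p: ((p.a, p.b), None),            # flatten: (children, context)
    lambda children, _ctx: Pair(*children),  # unflatten
)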
INFO:torch._dynamo.compiled_autograd.__compiled_autograd:TRACED GRAPH
===== Compiled autograd graph =====
<eval_with_key>.0 class CompiledAutograd(torch.nn.Module):
    def forward(self, inputs, sizes, hooks):
        # No stacktrace found for following nodes
        getitem: "f32[]" = inputs[0]
        getitem_1: "f32[2, 4]" = inputs[1]
        getitem_2: "f32[2, 4]" = inputs[2]
        getitem_3: "f32[4, 4]" = inputs[3]
        getitem_4: "f32[4, 4]" = inputs[4]
        getitem_5: "f32[2, 4]" = inputs[5]
        getitem_6: "f32[2, 4]" = inputs[6]
        getitem_7: "f32[4, 4]" = inputs[7]
        getitem_8: "f32[4]" = inputs[8]
        getitem_9: "f32[4]" = inputs[9]; inputs = None
        expand: "f32[2, 4]" = torch.ops.aten.expand.default(getitem, [2, 4]); getitem = None
        threshold_backward: "f32[2, 4]" = torch.ops.aten.threshold_backward.default(expand, getitem_1, 0); expand = getitem_1 = None
        t: "f32[4, 4]" = torch.ops.aten.t.default(getitem_3); getitem_3 = None
        mm: "f32[2, 4]" = torch.ops.aten.mm.default(threshold_backward, t); t = None
        t_1: "f32[4, 2]" = torch.ops.aten.t.default(threshold_backward)
        mm_1: "f32[4, 4]" = torch.ops.aten.mm.default(t_1, getitem_2); t_1 = getitem_2 = None
        t_2: "f32[4, 4]" = torch.ops.aten.t.default(mm_1); mm_1 = None
        sum_1: "f32[1, 4]" = torch.ops.aten.sum.dim_IntList(threshold_backward, [0], True); threshold_backward = None
        view: "f32[4]" = torch.ops.aten.view.default(sum_1, [4]); sum_1 = None
        accumulate_grad__3 = torch.ops.inductor.accumulate_grad_.default(getitem_9, view); getitem_9 = view = None
        t_3: "f32[4, 4]" = torch.ops.aten.t.default(t_2); t_2 = None
        accumulate_grad_ = torch.ops.inductor.accumulate_grad_.default(getitem_4, t_3); getitem_4 = t_3 = None
        threshold_backward_1: "f32[2, 4]" = torch.ops.aten.threshold_backward.default(mm, getitem_5, 0); mm = getitem_5 = None
        t_4: "f32[4, 2]" = torch.ops.aten.t.default(threshold_backward_1)
        mm_2: "f32[4, 4]" = torch.ops.aten.mm.default(t_4, getitem_6); t_4 = getitem_6 = None
        t_5: "f32[4, 4]" = torch.ops.aten.t.default(mm_2); mm_2 = None
        sum_2: "f32[1, 4]" = torch.ops.aten.sum.dim_IntList(threshold_backward_1, [0], True); threshold_backward_1 = None
        view_1: "f32[4]" = torch.ops.aten.view.default(sum_2, [4]); sum_2 = None
        accumulate_grad__2 = torch.ops.inductor.accumulate_grad_.default(getitem_8, view_1); getitem_8 = view_1 = None
        t_6: "f32[4, 4]" = torch.ops.aten.t.default(t_5); t_5 = None
        accumulate_grad__1 = torch.ops.inductor.accumulate_grad_.default(getitem_7, t_6); getitem_7 = t_6 = None
        return []
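Each torch.ops.inductor.accumulate_grad_ call above is an AccumulateGrad leaf node: it stores the freshly computed gradient into the matching parameter's .grad. When Dynamo re-traces this graph (next dump), the op is inlined through a Python polyfill, which is why only the torch.clone calls survive into __compiled_fn_0 and the cloned gradients are returned as graph outputs. A hedged sketch of the op's semantics, modelled on the polyfill referenced at torch/_dynamo/polyfill.py:44 below rather than copied from it:

import torch

def accumulate_grad_(param: torch.Tensor, new_grad: torch.Tensor) -> None:
    # sketch of the op's semantics, not the real implementation
    new_grad = torch.clone(new_grad)   # the only op that shows up in the Dynamo trace
    if param.grad is None:             # first backward: install the clone
        param.grad = new_grad
    else:                              # later backwards: accumulate in place
        param.grad.add_(new_grad)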
DEBUG:torch._dynamo.output_graph.__graph_code:TRACED GRAPH
===== __compiled_fn_0 =====
/home/xmfan/core/pytorch/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, L_inputs_0_ : torch.Tensor, L_inputs_1_ : torch.Tensor, L_inputs_2_ : torch.Tensor, L_inputs_3_ : torch.Tensor, L_inputs_5_ : torch.Tensor, L_inputs_6_ : torch.Tensor):
        l_inputs_0_ = L_inputs_0_
        l_inputs_1_ = L_inputs_1_
        l_inputs_2_ = L_inputs_2_
        l_inputs_3_ = L_inputs_3_
        l_inputs_5_ = L_inputs_5_
        l_inputs_6_ = L_inputs_6_
        # File: <eval_with_key>.0:15 in forward, code: expand = torch.ops.aten.expand.default(getitem, [2, 4]); getitem = None
        expand = torch.ops.aten.expand.default(l_inputs_0_, [2, 4]); l_inputs_0_ = None
        # File: <eval_with_key>.0:16 in forward, code: threshold_backward = torch.ops.aten.threshold_backward.default(expand, getitem_1, 0); expand = getitem_1 = None
        threshold_backward = torch.ops.aten.threshold_backward.default(expand, l_inputs_1_, 0); expand = l_inputs_1_ = None
        # File: <eval_with_key>.0:17 in forward, code: t = torch.ops.aten.t.default(getitem_3); getitem_3 = None
        t = torch.ops.aten.t.default(l_inputs_3_); l_inputs_3_ = None
        # File: <eval_with_key>.0:18 in forward, code: mm = torch.ops.aten.mm.default(threshold_backward, t); t = None
        mm = torch.ops.aten.mm.default(threshold_backward, t); t = None
        # File: <eval_with_key>.0:19 in forward, code: t_1 = torch.ops.aten.t.default(threshold_backward)
        t_1 = torch.ops.aten.t.default(threshold_backward)
        # File: <eval_with_key>.0:20 in forward, code: mm_1 = torch.ops.aten.mm.default(t_1, getitem_2); t_1 = getitem_2 = None
        mm_1 = torch.ops.aten.mm.default(t_1, l_inputs_2_); t_1 = l_inputs_2_ = None
        # File: <eval_with_key>.0:21 in forward, code: t_2 = torch.ops.aten.t.default(mm_1); mm_1 = None
        t_2 = torch.ops.aten.t.default(mm_1); mm_1 = None
        # File: <eval_with_key>.0:22 in forward, code: sum_1 = torch.ops.aten.sum.dim_IntList(threshold_backward, [0], True); threshold_backward = None
        sum_1 = torch.ops.aten.sum.dim_IntList(threshold_backward, [0], True); threshold_backward = None
        # File: <eval_with_key>.0:23 in forward, code: view = torch.ops.aten.view.default(sum_1, [4]); sum_1 = None
        view = torch.ops.aten.view.default(sum_1, [4]); sum_1 = None
        # File: /home/xmfan/core/pytorch/torch/_dynamo/polyfill.py:44 in accumulate_grad, code: new_grad = torch.clone(new_grad)
        new_grad = torch.clone(view); view = None
        # File: <eval_with_key>.0:25 in forward, code: t_3 = torch.ops.aten.t.default(t_2); t_2 = None
        t_3 = torch.ops.aten.t.default(t_2); t_2 = None
        # File: /home/xmfan/core/pytorch/torch/_dynamo/polyfill.py:44 in accumulate_grad, code: new_grad = torch.clone(new_grad)
        new_grad_1 = torch.clone(t_3); t_3 = None
        # File: <eval_with_key>.0:27 in forward, code: threshold_backward_1 = torch.ops.aten.threshold_backward.default(mm, getitem_5, 0); mm = getitem_5 = None
        threshold_backward_1 = torch.ops.aten.threshold_backward.default(mm, l_inputs_5_, 0); mm = l_inputs_5_ = None
        # File: <eval_with_key>.0:28 in forward, code: t_4 = torch.ops.aten.t.default(threshold_backward_1)
        t_4 = torch.ops.aten.t.default(threshold_backward_1)
        # File: <eval_with_key>.0:29 in forward, code: mm_2 = torch.ops.aten.mm.default(t_4, getitem_6); t_4 = getitem_6 = None
        mm_2 = torch.ops.aten.mm.default(t_4, l_inputs_6_); t_4 = l_inputs_6_ = None
        # File: <eval_with_key>.0:30 in forward, code: t_5 = torch.ops.aten.t.default(mm_2); mm_2 = None
        t_5 = torch.ops.aten.t.default(mm_2); mm_2 = None
        # File: <eval_with_key>.0:31 in forward, code: sum_2 = torch.ops.aten.sum.dim_IntList(threshold_backward_1, [0], True); threshold_backward_1 = None
        sum_2 = torch.ops.aten.sum.dim_IntList(threshold_backward_1, [0], True); threshold_backward_1 = None
        # File: <eval_with_key>.0:32 in forward, code: view_1 = torch.ops.aten.view.default(sum_2, [4]); sum_2 = None
        view_1 = torch.ops.aten.view.default(sum_2, [4]); sum_2 = None
        # File: /home/xmfan/core/pytorch/torch/_dynamo/polyfill.py:44 in accumulate_grad, code: new_grad = torch.clone(new_grad)
        new_grad_2 = torch.clone(view_1); view_1 = None
        # File: <eval_with_key>.0:34 in forward, code: t_6 = torch.ops.aten.t.default(t_5); t_5 = None
        t_6 = torch.ops.aten.t.default(t_5); t_5 = None
        # File: /home/xmfan/core/pytorch/torch/_dynamo/polyfill.py:44 in accumulate_grad, code: new_grad = torch.clone(new_grad)
        new_grad_3 = torch.clone(t_6); t_6 = None
        return (new_grad_1, new_grad_3, new_grad_2, new_grad)
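The gist title, "before boxing inputs", refers to the calling convention visible in this graph: every gradient and saved activation arrives as its own positional argument (L_inputs_0_ ... L_inputs_6_), so the caller has to keep all of them alive until the compiled function returns. A hedged sketch of the difference boxing would make; the function names and bodies below are illustrative only:

from typing import List
import torch

# Unboxed (what the graph above shows): separate positional arguments.
def forward_unboxed(grad_seed: torch.Tensor, relu_out: torch.Tensor) -> torch.Tensor:
    return torch.where(relu_out <= 0, torch.zeros(()), grad_seed.expand_as(relu_out))

# Boxed (as I read the title): one list argument whose entries are dropped as soon
# as they are consumed, so each input tensor can be freed earlier.
def forward_boxed(inputs: List[torch.Tensor]) -> torch.Tensor:
    grad_seed, relu_out = inputs
    inputs.clear()  # release the caller-visible references
    return torch.where(relu_out <= 0, torch.zeros(()), grad_seed.expand_as(relu_out))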
INFO:torch._functorch._aot_autograd.dispatch_and_compile_graph.__aot_graphs:TRACED GRAPH
===== Forward graph 0 =====
/home/xmfan/core/pytorch/torch/fx/_lazy_graph_module.py class <lambda>(torch.nn.Module):
    def forward(self, arg0_1: "f32[]", arg1_1: "f32[2, 4]", arg2_1: "f32[2, 4]", arg3_1: "f32[4, 4]", arg4_1: "f32[2, 4]", arg5_1: "f32[2, 4]"):
        # File: <eval_with_key>.0:15 in forward, code: expand = torch.ops.aten.expand.default(getitem, [2, 4]); getitem = None
        expand: "f32[2, 4]" = torch.ops.aten.expand.default(arg0_1, [2, 4]); arg0_1 = None
        # File: <eval_with_key>.0:16 in forward, code: threshold_backward = torch.ops.aten.threshold_backward.default(expand, getitem_1, 0); expand = getitem_1 = None
        le: "b8[2, 4]" = torch.ops.aten.le.Scalar(arg1_1, 0); arg1_1 = None
        scalar_tensor: "f32[]" = torch.ops.aten.scalar_tensor.default(0, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'))
        where: "f32[2, 4]" = torch.ops.aten.where.self(le, scalar_tensor, expand); le = scalar_tensor = expand = None
        # File: <eval_with_key>.0:17 in forward, code: t = torch.ops.aten.t.default(getitem_3); getitem_3 = None
        permute: "f32[4, 4]" = torch.ops.aten.permute.default(arg3_1, [1, 0]); arg3_1 = None
        # File: <eval_with_key>.0:18 in forward, code: mm = torch.ops.aten.mm.default(threshold_backward, t); t = None
        mm: "f32[2, 4]" = torch.ops.aten.mm.default(where, permute); permute = None
        # File: <eval_with_key>.0:19 in forward, code: t_1 = torch.ops.aten.t.default(threshold_backward)
        permute_1: "f32[4, 2]" = torch.ops.aten.permute.default(where, [1, 0])
        # File: <eval_with_key>.0:20 in forward, code: mm_1 = torch.ops.aten.mm.default(t_1, getitem_2); t_1 = getitem_2 = None
        mm_1: "f32[4, 4]" = torch.ops.aten.mm.default(permute_1, arg2_1); permute_1 = arg2_1 = None
        # File: <eval_with_key>.0:21 in forward, code: t_2 = torch.ops.aten.t.default(mm_1); mm_1 = None
        permute_2: "f32[4, 4]" = torch.ops.aten.permute.default(mm_1, [1, 0]); mm_1 = None
        # File: <eval_with_key>.0:22 in forward, code: sum_1 = torch.ops.aten.sum.dim_IntList(threshold_backward, [0], True); threshold_backward = None
        sum_1: "f32[1, 4]" = torch.ops.aten.sum.dim_IntList(where, [0], True); where = None
        # File: <eval_with_key>.0:23 in forward, code: view = torch.ops.aten.view.default(sum_1, [4]); sum_1 = None
        view: "f32[4]" = torch.ops.aten.view.default(sum_1, [4]); sum_1 = None
        # File: /home/xmfan/core/pytorch/torch/_dynamo/polyfill.py:44 in accumulate_grad, code: new_grad = torch.clone(new_grad)
        clone: "f32[4]" = torch.ops.aten.clone.default(view); view = None
        # File: <eval_with_key>.0:25 in forward, code: t_3 = torch.ops.aten.t.default(t_2); t_2 = None
        permute_3: "f32[4, 4]" = torch.ops.aten.permute.default(permute_2, [1, 0]); permute_2 = None
        # File: /home/xmfan/core/pytorch/torch/_dynamo/polyfill.py:44 in accumulate_grad, code: new_grad = torch.clone(new_grad)
        clone_1: "f32[4, 4]" = torch.ops.aten.clone.default(permute_3); permute_3 = None
        # File: <eval_with_key>.0:27 in forward, code: threshold_backward_1 = torch.ops.aten.threshold_backward.default(mm, getitem_5, 0); mm = getitem_5 = None
        le_1: "b8[2, 4]" = torch.ops.aten.le.Scalar(arg4_1, 0); arg4_1 = None
        scalar_tensor_1: "f32[]" = torch.ops.aten.scalar_tensor.default(0, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'))
        where_1: "f32[2, 4]" = torch.ops.aten.where.self(le_1, scalar_tensor_1, mm); le_1 = scalar_tensor_1 = mm = None
        # File: <eval_with_key>.0:28 in forward, code: t_4 = torch.ops.aten.t.default(threshold_backward_1)
        permute_4: "f32[4, 2]" = torch.ops.aten.permute.default(where_1, [1, 0])
        # File: <eval_with_key>.0:29 in forward, code: mm_2 = torch.ops.aten.mm.default(t_4, getitem_6); t_4 = getitem_6 = None
        mm_2: "f32[4, 4]" = torch.ops.aten.mm.default(permute_4, arg5_1); permute_4 = arg5_1 = None
        # File: <eval_with_key>.0:30 in forward, code: t_5 = torch.ops.aten.t.default(mm_2); mm_2 = None
        permute_5: "f32[4, 4]" = torch.ops.aten.permute.default(mm_2, [1, 0]); mm_2 = None
        # File: <eval_with_key>.0:31 in forward, code: sum_2 = torch.ops.aten.sum.dim_IntList(threshold_backward_1, [0], True); threshold_backward_1 = None
        sum_2: "f32[1, 4]" = torch.ops.aten.sum.dim_IntList(where_1, [0], True); where_1 = None
        # File: <eval_with_key>.0:32 in forward, code: view_1 = torch.ops.aten.view.default(sum_2, [4]); sum_2 = None
        view_1: "f32[4]" = torch.ops.aten.view.default(sum_2, [4]); sum_2 = None
        # File: /home/xmfan/core/pytorch/torch/_dynamo/polyfill.py:44 in accumulate_grad, code: new_grad = torch.clone(new_grad)
        clone_2: "f32[4]" = torch.ops.aten.clone.default(view_1); view_1 = None
        # File: <eval_with_key>.0:34 in forward, code: t_6 = torch.ops.aten.t.default(t_5); t_5 = None
        permute_6: "f32[4, 4]" = torch.ops.aten.permute.default(permute_5, [1, 0]); permute_5 = None
        # File: /home/xmfan/core/pytorch/torch/_dynamo/polyfill.py:44 in accumulate_grad, code: new_grad = torch.clone(new_grad)
        clone_3: "f32[4, 4]" = torch.ops.aten.clone.default(permute_6); permute_6 = None
        return (clone_1, clone_3, clone_2, clone)
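At the AOT stage the graph has been decomposed and functionalized: aten.threshold_backward becomes le + scalar_tensor + where, and aten.t becomes permute. A hand-written version of the threshold_backward decomposition shown above (a sketch for reference, not the registered decomposition itself):

import torch

def threshold_backward_ref(grad_output: torch.Tensor, out: torch.Tensor, threshold: float) -> torch.Tensor:
    mask = out <= threshold                                   # aten.le.Scalar
    zero = torch.scalar_tensor(0.0, dtype=grad_output.dtype)  # aten.scalar_tensor
    return torch.where(mask, zero, grad_output)               # aten.where.self

# ReLU backward zeroes the incoming gradient wherever the forward output was <= 0:
g, y = torch.randn(2, 4), torch.relu(torch.randn(2, 4))
assert torch.equal(threshold_backward_ref(g, y, 0), g * (y > 0))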
DEBUG:torch._inductor.graph.__output_code:Output code:
from ctypes import c_void_p, c_long
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align
from torch import device, empty_strided
from torch._inductor.codecache import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
from torch._inductor.codegen.multi_kernel import MultiKernelCall
aten = torch.ops.aten
inductor_ops = torch.ops.inductor
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
alloc_from_pool = torch.ops.inductor._alloc_from_pool
reinterpret_tensor = torch.ops.inductor._reinterpret_tensor
async_compile = AsyncCompile()
cpp_fused_threshold_backward_0 = async_compile.cpp_pybinding(['const float*', 'const float*', 'float*'], '''
#include "/tmp/torchinductor_xmfan/np/cnpfagnjbuwis32i7j7u7gflhlxcn7ws2mtrujf26hxyo6pvmx6t.h"
extern "C" void kernel(const float* in_ptr0,
const float* in_ptr1,
float* out_ptr0)
{
{
#pragma omp simd simdlen(8)
for(long x0=static_cast<long>(0L); x0<static_cast<long>(8L); x0+=static_cast<long>(1L))
{
auto tmp0 = in_ptr0[static_cast<long>(x0)];
auto tmp3 = in_ptr1[static_cast<long>(0L)];
auto tmp1 = static_cast<float>(0.0);
auto tmp2 = tmp0 <= tmp1;
auto tmp4 = tmp2 ? tmp1 : tmp3;
out_ptr0[static_cast<long>(x0)] = tmp4;
}
}
}
''')
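# Illustrative reference only, not Inductor output: cpp_fused_threshold_backward_0
# above fuses the expand of the scalar grad seed with ReLU's backward. Each of the
# 8 outputs is 0 where the saved ReLU output was <= 0, else the seed value. A
# plain-torch equivalent (the function name is hypothetical):
def _fused_threshold_backward_0_ref(relu_out, grad_seed):
    # relu_out: (2, 4) saved activation; grad_seed: 0-dim gradient of the sum() loss
    return torch.where(relu_out <= 0, torch.zeros((), dtype=relu_out.dtype), grad_seed)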
cpp_fused_threshold_backward_1 = async_compile.cpp_pybinding(['float*', 'const float*'], '''
#include "/tmp/torchinductor_xmfan/np/cnpfagnjbuwis32i7j7u7gflhlxcn7ws2mtrujf26hxyo6pvmx6t.h"
extern "C" void kernel(float* in_out_ptr0,
const float* in_ptr0)
{
{
#pragma omp simd simdlen(8)
for(long x0=static_cast<long>(0L); x0<static_cast<long>(8L); x0+=static_cast<long>(1L))
{
auto tmp0 = in_ptr0[static_cast<long>(x0)];
auto tmp3 = in_out_ptr0[static_cast<long>(x0)];
auto tmp1 = static_cast<float>(0.0);
auto tmp2 = tmp0 <= tmp1;
auto tmp4 = tmp2 ? tmp1 : tmp3;
in_out_ptr0[static_cast<long>(x0)] = tmp4;
}
}
}
''')
cpp_fused_sum_2 = async_compile.cpp_pybinding(['const float*', 'const float*', 'float*', 'float*'], '''
#include "/tmp/torchinductor_xmfan/np/cnpfagnjbuwis32i7j7u7gflhlxcn7ws2mtrujf26hxyo6pvmx6t.h"
extern "C" void kernel(const float* in_ptr0,
const float* in_ptr1,
float* out_ptr0,
float* out_ptr1)
{
{
#pragma omp simd simdlen(8)
for(long x0=static_cast<long>(0L); x0<static_cast<long>(4L); x0+=static_cast<long>(1L))
{
auto tmp0 = in_ptr0[static_cast<long>(x0)];
auto tmp1 = in_ptr0[static_cast<long>(4L + x0)];
auto tmp2 = decltype(tmp0)(tmp0 + tmp1);
out_ptr0[static_cast<long>(x0)] = tmp2;
}
}
{
#pragma omp simd simdlen(8)
for(long x0=static_cast<long>(0L); x0<static_cast<long>(4L); x0+=static_cast<long>(1L))
{
auto tmp0 = in_ptr1[static_cast<long>(x0)];
auto tmp1 = in_ptr1[static_cast<long>(4L + x0)];
auto tmp2 = decltype(tmp0)(tmp0 + tmp1);
out_ptr1[static_cast<long>(x0)] = tmp2;
}
}
}
''')
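# Illustrative reference only, not Inductor output: cpp_fused_sum_2 above computes
# both bias gradients in one kernel; each output row is the per-column sum over the
# batch dimension of a (2, 4) gradient, i.e. grad.sum(dim=0, keepdim=True). The
# function name is hypothetical:
def _fused_sum_2_ref(grad_a, grad_b):
    return grad_a.sum(dim=0, keepdim=True), grad_b.sum(dim=0, keepdim=True)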
async_compile.wait(globals())
del async_compile
def call(args):
    arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1 = args
    args.clear()
    assert_size_stride(arg0_1, (), ())
    assert_size_stride(arg1_1, (2, 4), (4, 1))
    assert_size_stride(arg2_1, (2, 4), (4, 1))
    assert_size_stride(arg3_1, (4, 4), (1, 4))
    assert_size_stride(arg4_1, (2, 4), (4, 1))
    assert_size_stride(arg5_1, (2, 4), (4, 1))
    buf0 = empty_strided_cpu((2, 4), (4, 1), torch.float32)
    cpp_fused_threshold_backward_0(arg1_1, arg0_1, buf0)
    del arg0_1
    del arg1_1
    buf1 = empty_strided_cpu((4, 4), (4, 1), torch.float32)
    # Source Nodes: [mm_1], Original ATen: [aten.mm]
    extern_kernels.mm(reinterpret_tensor(buf0, (4, 2), (1, 4), 0), arg2_1, out=buf1)
    del arg2_1
    buf2 = empty_strided_cpu((2, 4), (4, 1), torch.float32)
    # Source Nodes: [mm], Original ATen: [aten.mm]
    extern_kernels.mm(buf0, reinterpret_tensor(arg3_1, (4, 4), (4, 1), 0), out=buf2)
    del arg3_1
    buf3 = buf2; del buf2 # reuse
    cpp_fused_threshold_backward_1(buf3, arg4_1)
    del arg4_1
    buf4 = empty_strided_cpu((4, 4), (4, 1), torch.float32)
    # Source Nodes: [mm_2], Original ATen: [aten.mm]
    extern_kernels.mm(reinterpret_tensor(buf3, (4, 2), (1, 4), 0), arg5_1, out=buf4)
    del arg5_1
    buf5 = empty_strided_cpu((1, 4), (4, 1), torch.float32)
    buf6 = empty_strided_cpu((1, 4), (4, 1), torch.float32)
    cpp_fused_sum_2(buf3, buf0, buf5, buf6)
    return (reinterpret_tensor(buf1, (4, 4), (4, 1), 0), reinterpret_tensor(buf4, (4, 4), (4, 1), 0), reinterpret_tensor(buf5, (4, ), (1, ), 0), reinterpret_tensor(buf6, (4, ), (1, ), 0), )
def benchmark_compiled_module(times=10, repeat=10):
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance
    arg0_1 = rand_strided((), (), device='cpu', dtype=torch.float32)
    arg1_1 = rand_strided((2, 4), (4, 1), device='cpu', dtype=torch.float32)
    arg2_1 = rand_strided((2, 4), (4, 1), device='cpu', dtype=torch.float32)
    arg3_1 = rand_strided((4, 4), (1, 4), device='cpu', dtype=torch.float32)
    arg4_1 = rand_strided((2, 4), (4, 1), device='cpu', dtype=torch.float32)
    arg5_1 = rand_strided((2, 4), (4, 1), device='cpu', dtype=torch.float32)
    fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1])
    return print_performance(fn, times=times, repeat=repeat)
if __name__ == "__main__":
    from torch._inductor.wrapper_benchmark import compiled_module_main
    compiled_module_main('None', benchmark_compiled_module)
INFO:torch._inductor.graph.__output_code:Output code written to: /tmp/torchinductor_xmfan/sr/csrajnecdck7a5hi27h53aarr6des3ltlpfdtj7hhfmqfk45rd7x.py
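For reference, the four dumps above correspond to separate logging artifacts (the logger suffixes __compiled_autograd, __graph_code, __aot_graphs and __output_code). A hedged sketch of enabling them programmatically; the keyword names are inferred from those suffixes and may differ by PyTorch version:

import torch

torch._logging.set_logs(
    compiled_autograd=True,  # the "Compiled autograd graph" dump
    graph_code=True,         # Dynamo's __compiled_fn_0 dump
    aot_graphs=True,         # the AOT "Forward graph 0" dump
    output_code=True,        # Inductor's generated C++/Python wrapper
)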
.
----------------------------------------------------------------------
Ran 1 test in 0.524s
OK
compiled_autograd [('captures', 1), ('compiles', 1)]
inline_call []
stats [('calls_captured', 21), ('unique_graphs', 1)]
inductor [('pattern_matcher_count', 2), ('pattern_matcher_nodes', 2), ('fxgraph_cache_miss', 1)]
aot_autograd [('total', 1), ('ok', 1)]
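The trailing lines are counter dumps that the test harness prints after the run; they appear to come from torch._dynamo.utils.counters (an assumption on my part, based on the key names). A sketch of reading the same counters directly:

from torch._dynamo.utils import counters

print(dict(counters["compiled_autograd"]))  # e.g. {'captures': 1, 'compiles': 1}
print(dict(counters["stats"]))              # e.g. {'calls_captured': 21, 'unique_graphs': 1}
print(dict(counters["inductor"]))           # pattern_matcher / fxgraph_cache stats
print(dict(counters["aot_autograd"]))       # e.g. {'total': 1, 'ok': 1}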