# Output code from TORCH_LOGS="output_code" python simple.py, with inputs baked in.
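#
# A minimal sketch of the kind of simple.py that could produce a graph like
# this one -- an assumption, not the original script. The two
# inductor_ops.accumulate_grad_ calls in call() below are characteristic of a
# compiled autograd backward graph, so the sketch compiles the backward pass
# of a tiny two-matmul model whose parameter shapes match arg3_1 (2, 1) and
# arg5_1 (1, 2). The torch._dynamo.compiled_autograd.enable API path is part
# of the assumption (it matches the prototype API of this PyTorch era):
#
#     import torch
#     from torch._dynamo import compiled_autograd
#
#     x = torch.randn(1, 2)
#     w1 = torch.nn.Parameter(torch.randn(2, 1))
#     w2 = torch.nn.Parameter(torch.randn(1, 2))
#
#     loss = (x @ w1 @ w2).sum()
#     with compiled_autograd.enable(torch.compile):
#         loss.backward()  # backward graph is traced and compiled by inductor
#     print(w1.grad, w2.grad)
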
from ctypes import c_void_p, c_long
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align
from torch import device, empty_strided
from torch._inductor.codecache import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
from torch._inductor.codegen.multi_kernel import MultiKernelCall
aten = torch.ops.aten
inductor_ops = torch.ops.inductor
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
alloc_from_pool = torch.ops.inductor._alloc_from_pool
reinterpret_tensor = torch.ops.inductor._reinterpret_tensor
async_compile = AsyncCompile()
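# cpp_fused_mm_0 broadcasts the scalar gradient seed (in_ptr0[0]) into two
# 4-element (2x2) output buffers, which feed the downstream mm calls in call().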
cpp_fused_mm_0 = async_compile.cpp_pybinding(['const float*', 'float*', 'float*'], '''
#include "/tmp/torchinductor_xmfan/lg/clghje745biezhrbrw5fghxqjaj76ck5jms7466s4ax63eruswf5.h"
extern "C" void kernel(const float* in_ptr0,
                       float* out_ptr0,
                       float* out_ptr1)
{
    {
        #pragma omp simd simdlen(8)
        for(long x0=static_cast<long>(0L); x0<static_cast<long>(4L); x0+=static_cast<long>(1L))
        {
            auto tmp0 = in_ptr0[static_cast<long>(0L)];
            out_ptr0[static_cast<long>(x0)] = tmp0;
            out_ptr1[static_cast<long>(x0)] = tmp0;
        }
    }
}
''')
async_compile.wait(globals())
del async_compile
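# call() is the compiled backward wrapper: it consumes the gradient seed
# (arg0_1), what are presumably saved forward tensors (arg1_1, arg2_1,
# arg4_1), and the parameters (arg3_1, arg5_1), then writes gradients
# straight into the parameters via inductor_ops.accumulate_grad_ and
# returns no outputs.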
def call(args):
    arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1 = args
    args.clear()
    assert_size_stride(arg0_1, (), ())
    assert_size_stride(arg1_1, (1, 2), (1, 1))
    assert_size_stride(arg2_1, (2, 1), (1, 1))
    assert_size_stride(arg3_1, (2, 1), (1, 1))
    assert_size_stride(arg4_1, (2, 2), (2, 1))
    assert_size_stride(arg5_1, (1, 2), (2, 1))
    buf0 = empty_strided_cpu((2, 2), (2, 1), torch.float32)
    buf4 = empty_strided_cpu((2, 2), (2, 1), torch.float32)
    cpp_fused_mm_0(arg0_1, buf0, buf4)
    del arg0_1
    buf1 = empty_strided_cpu((2, 1), (1, 1), torch.float32)
    # Source Nodes: [mm], Original ATen: [aten.mm]
    extern_kernels.mm(buf0, arg2_1, out=buf1)
    del arg2_1
    del buf0
    inductor_ops.accumulate_grad_(arg3_1, reinterpret_tensor(buf1, (2, 1), (1, 1), 0))
    del arg3_1
    buf5 = buf1; del buf1  # reuse
    # Source Nodes: [mm_1], Original ATen: [aten.mm]
    extern_kernels.mm(buf4, reinterpret_tensor(arg1_1, (2, 1), (1, 1), 0), out=buf5)
    del arg1_1
    del buf4
    buf6 = empty_strided_cpu((1, 2), (2, 1), torch.float32)
    # Source Nodes: [mm_2], Original ATen: [aten.mm]
    extern_kernels.mm(reinterpret_tensor(buf5, (1, 2), (1, 1), 0), arg4_1, out=buf6)
    del arg4_1
    del buf5
    inductor_ops.accumulate_grad_(arg5_1, reinterpret_tensor(buf6, (1, 2), (2, 1), 0))
    del buf6
    del arg5_1
    return ()
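# The benchmark harness below appears hand-edited relative to inductor's
# usual output: the rand_strided inputs it normally fabricates are commented
# out and replaced with the concrete tensors baked in from the original run
# (per the description at the top of this file).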
# def benchmark_compiled_module(times=10, repeat=10):
def benchmark_compiled_module():
    # from torch._dynamo.testing import rand_strided
    # from torch._inductor.utils import print_performance
    # arg0_1 = rand_strided((), (), device='cpu', dtype=torch.float32)
    # arg1_1 = rand_strided((1, 2), (1, 1), device='cpu', dtype=torch.float32)
    # arg2_1 = rand_strided((2, 1), (1, 1), device='cpu', dtype=torch.float32)
    # arg3_1 = rand_strided((2, 1), (1, 1), device='cpu', dtype=torch.float32)
    # arg4_1 = rand_strided((2, 2), (2, 1), device='cpu', dtype=torch.float32)
    # arg5_1 = rand_strided((1, 2), (2, 1), device='cpu', dtype=torch.float32)
    arg0_1 = torch.tensor(1.)
    arg1_1 = torch.tensor([[-0.8230, -0.7359]])
    arg2_1 = torch.tensor([[ 0.2271], [-0.5247]])
    arg3_1 = torch.nn.Parameter(torch.tensor([[-0.8230], [-0.7359]]))
    arg4_1 = torch.tensor([[-2.1788,  0.5684], [-1.0845, -1.3986]])
    arg5_1 = torch.nn.Parameter(torch.tensor([[-0.0053,  0.3793]]))
    fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1])
    torch.manual_seed(0)
    fn()
    print(arg5_1.grad)
    # tensor([[5.0872, 1.2942]])
    print(arg3_1.grad)
    # tensor([[-0.2976], [-0.2976]])
    # return print_performance(fn, times=times, repeat=repeat)
if __name__ == "__main__":
    # from torch._inductor.wrapper_benchmark import compiled_module_main
    # compiled_module_main('None', benchmark_compiled_module)
    benchmark_compiled_module()