Skip to content

Instantly share code, notes, and snippets.

@xmfan
Created March 26, 2024 22:42
Show Gist options
  • Save xmfan/398e43d4280456047c8d1a88073c7dd9 to your computer and use it in GitHub Desktop.
Save xmfan/398e43d4280456047c8d1a88073c7dd9 to your computer and use it in GitHub Desktop.
from ctypes import c_void_p, c_long
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align
from torch import device, empty_strided
from torch._inductor.codecache import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
from torch._inductor.codegen.multi_kernel import MultiKernelCall
# Short module-level aliases for the torch op namespaces and the private
# guard/allocation helpers referenced by the generated wrapper code below.
aten = torch.ops.aten
inductor_ops = torch.ops.inductor
# Used by call() to check each input's size and stride before the kernel runs.
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
# Allocators taking (size, stride, dtype); only the CPU variant is used in this file.
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
# Not referenced elsewhere in the visible file; kept as bound aliases.
alloc_from_pool = torch.ops.inductor._alloc_from_pool
reinterpret_tensor = torch.ops.inductor._reinterpret_tensor
# Compiler handle used to register the C++ kernel below; it is deleted again
# once compilation has been waited on.
async_compile = AsyncCompile()
# Fused elementwise CPU kernel, bound with the signature
# (const float* in_ptr0, const half* in_ptr1, half* out_ptr0).
#
# For x0 in [0, 4320) and x1 in [0, 8), with i = x1 + 8*x0, it computes in float:
#   a      = in_ptr0[i]
#   b      = float(in_ptr1[x1])        # indexed by x1 only, i.e. reused for every x0
#   out[i] = half((a != 1.0 ? max_propagate_nan(7.2 - b, 0.0) : 0.0)
#                 + (a != -1.0 ? b : 0.0))
#
# 4320 * 8 == 10*9*8*6*8, the element count of the contiguous
# (10, 9, 8, 6, 8) tensors that call() asserts and allocates.
# max_propagate_nan is not defined in this file; presumably it comes from the
# included header and is a NaN-propagating max -- TODO confirm.
# NOTE(review): the #include is an absolute path into /tmp/torchinductor_xmfan,
# so this source looks tied to the machine/cache it was generated on -- confirm
# before expecting it to compile elsewhere.
cpp_fused_add_clamp_min_fill_ne_sub_where_zeros_like_0 = async_compile.cpp_pybinding(['const float*', 'const half*', 'half*'], '''
#include "/tmp/torchinductor_xmfan/np/cnpfagnjbuwis32i7j7u7gflhlxcn7ws2mtrujf26hxyo6pvmx6t.h"
extern "C" void kernel(const float* in_ptr0,
const half* in_ptr1,
half* out_ptr0)
{
{
#pragma GCC ivdep
for(long x0=static_cast<long>(0L); x0<static_cast<long>(4320L); x0+=static_cast<long>(1L))
{
#pragma omp simd simdlen(8)
for(long x1=static_cast<long>(0L); x1<static_cast<long>(8L); x1+=static_cast<long>(1L))
{
auto tmp0 = in_ptr0[static_cast<long>(x1 + (8L*x0))];
auto tmp3 = static_cast<float>(in_ptr1[static_cast<long>(x1)]);
auto tmp1 = static_cast<float>(1.0);
auto tmp2 = tmp0 != tmp1;
auto tmp4 = c10::convert<float>(tmp3);
auto tmp5 = static_cast<float>(7.2);
auto tmp6 = decltype(tmp5)(tmp5 - tmp4);
auto tmp7 = static_cast<float>(0.0);
auto tmp8 = max_propagate_nan(tmp6, tmp7);
auto tmp9 = tmp2 ? tmp8 : tmp7;
auto tmp10 = static_cast<float>(-1.0);
auto tmp11 = tmp0 != tmp10;
auto tmp12 = tmp11 ? tmp4 : tmp7;
auto tmp13 = decltype(tmp9)(tmp9 + tmp12);
auto tmp14 = c10::convert<half>(tmp13);
out_ptr0[static_cast<long>(x1 + (8L*x0))] = tmp14;
}
}
}
}
''')
# Passes this module's globals to the compiler's wait(); presumably this blocks
# until the kernel is built and rebinds the name above to the finished
# callable -- verify against AsyncCompile.wait.
async_compile.wait(globals())
# The compiler handle is not needed after this point.
del async_compile
def call(args):
    """Run the fused C++ kernel on the two tensors held in ``args``.

    ``args`` must be a mutable sequence of exactly two tensors. It is emptied
    in place before any work is done. The first tensor is checked against
    size (10, 9, 8, 6, 8) / stride (3456, 384, 48, 8, 1), the second against
    size (1, 1, 1, 1, 8) / stride (8, 8, 8, 8, 1).

    Returns a one-element tuple holding a newly allocated float16 CPU buffer
    with the same size and stride as the first input, filled by the kernel.
    """
    full_input, row_input = args
    # Empty the caller's list in place; from here on the two locals above are
    # the only references this function took from it.
    args.clear()

    full_size = (10, 9, 8, 6, 8)
    full_stride = (3456, 384, 48, 8, 1)

    assert_size_stride(full_input, full_size, full_stride)
    assert_size_stride(row_input, (1, 1, 1, 1, 8), (8, 8, 8, 8, 1))

    result = empty_strided_cpu(full_size, full_stride, torch.float16)
    cpp_fused_add_clamp_min_fill_ne_sub_where_zeros_like_0(
        full_input, row_input, result
    )

    # Release the local references to the inputs before returning.
    del full_input
    del row_input
    return (result, )
def benchmark_compiled_module(times=10, repeat=10):
    """Benchmark ``call`` on randomly generated CPU inputs.

    Builds a float32 tensor of size (10, 9, 8, 6, 8) and a float16 tensor of
    size (1, 1, 1, 1, 8) with ``rand_strided``, then hands a zero-argument
    callable that invokes ``call`` on them to ``print_performance``.

    ``times`` and ``repeat`` are forwarded unchanged to ``print_performance``,
    whose return value is returned as-is.
    """
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance

    full_input = rand_strided(
        (10, 9, 8, 6, 8), (3456, 384, 48, 8, 1), device='cpu', dtype=torch.float32
    )
    row_input = rand_strided(
        (1, 1, 1, 1, 8), (8, 8, 8, 8, 1), device='cpu', dtype=torch.float16
    )

    def run_once():
        # call() clears the list it receives, so build a fresh one every time.
        return call([full_input, row_input])

    return print_performance(run_once, times=times, repeat=repeat)
if __name__ == "__main__":
    # Script entry point: pass the benchmark function to inductor's
    # compiled-module driver. The first argument is the literal string 'None'.
    from torch._inductor.wrapper_benchmark import compiled_module_main

    compiled_module_main('None', benchmark_compiled_module)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment