Skip to content

Instantly share code, notes, and snippets.

@zhuhaozhe
Created April 17, 2024 06:08
Show Gist options
  • Save zhuhaozhe/d99974f84d78f76f393fbeec7ee3b03e to your computer and use it in GitHub Desktop.
Save zhuhaozhe/d99974f84d78f76f393fbeec7ee3b03e to your computer and use it in GitHub Desktop.
from ctypes import c_void_p, c_long
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align
from torch import device, empty_strided
from torch._inductor.codecache import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
from torch._inductor.codegen.multi_kernel import MultiKernelCall
aten = torch.ops.aten
inductor_ops = torch.ops.inductor
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
alloc_from_pool = torch.ops.inductor._alloc_from_pool
reinterpret_tensor = torch.ops.inductor._reinterpret_tensor
async_compile = AsyncCompile()
# Fused CPU kernel produced by TorchInductor, compiled and bound via
# async_compile. What the C++ below computes:
#   out_ptr0[0:10000]  -- zero-filled, then out_ptr0[j] += 1.0 for every
#                         j = in_ptr0[209985 + x0], x0 in [0, 209985)
#                         (negative j wrapped by +10000, then bounds-checked).
#   out_ptr1[0:209985] -- out_ptr1[x0] = f(out_ptr0[i]) * f(out_ptr0[j]) with
#                         i = in_ptr0[x0], j = in_ptr0[209985 + x0] (same
#                         wrapping and checks) and f(v) = v ** -0.5, replaced
#                         by 0 where that power is +inf.
# Review changes relative to the generated source:
#   * Removed `#pragma GCC ivdep` from the scatter-add loop. The pragma asserts
#     there are no loop-carried dependencies. Repeated indices make iterations
#     update the same out_ptr0 element, so a vectorized version could lose
#     increments.
#   * In the 16-wide loop, each gather of out_ptr0 and each pow(-0.5) is now
#     computed once and reused. The generated code repeated gather + pow for
#     both the inf test and the blended value. out_ptr0 is not written in that
#     loop, so the results are identical.
# NOTE(review): the vector loops step by 16 lanes (Vectorized<float>,
# VectorizedN<int64_t,2>, std::array<..., 16>), which presumes an AVX-512
# build of at::vec -- confirm the target ISA.
# NOTE(review): the #include path points into a per-user inductor cache dir.
cpp_fused_eq_index_masked_fill_mul_new_zeros_ones_pow_scatter_add_0 = async_compile.cpp_pybinding(['const int64_t*', 'float*', 'float*'], '''
#include "/tmp/torchinductor_haozhe/kn/cknhgwrjj77u7ocflnojvr6pbo5cbrerqim3blda2rcpzbh4gnow.h"
extern "C" void kernel(const int64_t* in_ptr0,
                       float* out_ptr0,
                       float* out_ptr1)
{
    {
        // out_ptr0[0:10000] = 0, one 16-float vector per iteration.
        for(long x0=static_cast<long>(0L); x0<static_cast<long>(10000L); x0+=static_cast<long>(16L))
        {
            auto tmp0 = static_cast<float>(0.0);
            auto tmp1 = at::vec::Vectorized<float>(tmp0);
            tmp1.store(out_ptr0 + static_cast<long>(x0));
        }
    }
    {
        // out_ptr0[j] += 1 for j = in_ptr0[209985 + x0]. Indices may repeat,
        // so iterations can depend on each other through out_ptr0; do not
        // mark this loop `#pragma GCC ivdep`.
        for(long x0=static_cast<long>(0L); x0<static_cast<long>(209985L); x0+=static_cast<long>(1L))
        {
            auto tmp0 = in_ptr0[static_cast<long>(209985L + x0)];
            auto tmp1 = decltype(tmp0)(tmp0 + 10000);
            auto tmp2 = tmp0 < 0;
            auto tmp3 = tmp2 ? tmp1 : tmp0;
            TORCH_CHECK((0 <= tmp3) & (tmp3 < 10000L), "index out of bounds: 0 <= tmp3 < 10000L")
            auto tmp4 = static_cast<float>(1.0);
            out_ptr0[static_cast<long>(tmp3)] += tmp4;
        }
    }
    {
        // 16 outputs per iteration: gather out_ptr0 at both index rows,
        // apply pow(-0.5) with +inf mapped to 0, and multiply.
        for(long x0=static_cast<long>(0L); x0<static_cast<long>(209984L); x0+=static_cast<long>(16L))
        {
            auto tmp0 = at::vec::VectorizedN<int64_t,2>::loadu(in_ptr0 + static_cast<long>(x0));
            auto tmp22 = at::vec::VectorizedN<int64_t,2>::loadu(in_ptr0 + static_cast<long>(209985L + x0));
            auto tmp1 = static_cast<int64_t>(10000);
            auto tmp2 = at::vec::VectorizedN<int64_t,2>(tmp1);
            auto tmp3 = tmp0 + tmp2;
            auto tmp4 = static_cast<int64_t>(0);
            auto tmp5 = at::vec::VectorizedN<int64_t,2>(tmp4);
            auto tmp6 = at::vec::VecMask<int64_t,2>(tmp0 < tmp5);
            auto tmp7 = decltype(tmp3)::blendv(tmp0, tmp3, tmp6.template cast<int64_t,2>());
            TORCH_CHECK((at::vec::VecMask<int64_t,2>((at::vec::VectorizedN<int64_t,2>(0) <= tmp7) & (tmp7 < at::vec::VectorizedN<int64_t,2>(10000L)))).all_masked(), "index out of bounds: 0 <= tmp7 < 10000L")
            auto tmp8 =
            [&]
            {
                __at_align__ std::array<int64_t, 16> tmpbuf;
                tmp7.store(tmpbuf.data());
                return tmpbuf;
            }
            ()
            ;
            auto tmp9 =
            [&]
            {
                __at_align__ std::array<float, 16> tmpbuf;
                #pragma GCC unroll 16
                for (long x0_inner = 0; x0_inner < 16; x0_inner++)
                {
                    tmpbuf[x0_inner] = out_ptr0[static_cast<long>(tmp8[x0_inner])];
                }
                return at::vec::Vectorized<float>::loadu(tmpbuf.data());
            }
            ()
            ;
            auto tmp10 = static_cast<float>(-0.5);
            auto tmp11 = at::vec::Vectorized<float>(tmp10);
            auto tmp12 = tmp9.pow(tmp11);
            auto tmp13 = std::numeric_limits<float>::infinity();
            auto tmp14 = at::vec::Vectorized<float>(tmp13);
            auto tmp15 = at::vec::VecMask<float,1>(tmp12 == tmp14);
            auto tmp19 = static_cast<float>(0.0);
            auto tmp20 = at::vec::Vectorized<float>(tmp19);
            // Reuse tmp12 instead of re-gathering out_ptr0 at tmp7 and
            // recomputing pow: out_ptr0 is read-only in this loop.
            auto tmp21 = decltype(tmp20)::blendv(tmp12, tmp20, tmp15.template cast<float,1>());
            auto tmp23 = tmp22 + tmp2;
            auto tmp24 = at::vec::VecMask<int64_t,2>(tmp22 < tmp5);
            auto tmp25 = decltype(tmp23)::blendv(tmp22, tmp23, tmp24.template cast<int64_t,2>());
            TORCH_CHECK((at::vec::VecMask<int64_t,2>((at::vec::VectorizedN<int64_t,2>(0) <= tmp25) & (tmp25 < at::vec::VectorizedN<int64_t,2>(10000L)))).all_masked(), "index out of bounds: 0 <= tmp25 < 10000L")
            auto tmp26 =
            [&]
            {
                __at_align__ std::array<int64_t, 16> tmpbuf;
                tmp25.store(tmpbuf.data());
                return tmpbuf;
            }
            ()
            ;
            auto tmp27 =
            [&]
            {
                __at_align__ std::array<float, 16> tmpbuf;
                #pragma GCC unroll 16
                for (long x0_inner = 0; x0_inner < 16; x0_inner++)
                {
                    tmpbuf[x0_inner] = out_ptr0[static_cast<long>(tmp26[x0_inner])];
                }
                return at::vec::Vectorized<float>::loadu(tmpbuf.data());
            }
            ()
            ;
            auto tmp28 = tmp27.pow(tmp11);
            auto tmp29 = at::vec::VecMask<float,1>(tmp28 == tmp14);
            // Same reuse for the second index row (tmp28 instead of a
            // second gather + pow).
            auto tmp33 = decltype(tmp20)::blendv(tmp28, tmp20, tmp29.template cast<float,1>());
            auto tmp34 = tmp21 * tmp33;
            tmp34.store(out_ptr1 + static_cast<long>(x0));
        }
        // Scalar tail: 209985 = 16 * 13124 + 1, so one element remains.
        #pragma omp simd simdlen(8)
        for(long x0=static_cast<long>(209984L); x0<static_cast<long>(209985L); x0+=static_cast<long>(1L))
        {
            auto tmp0 = in_ptr0[static_cast<long>(x0)];
            auto tmp11 = in_ptr0[static_cast<long>(209985L + x0)];
            auto tmp1 = decltype(tmp0)(tmp0 + 10000);
            auto tmp2 = tmp0 < 0;
            auto tmp3 = tmp2 ? tmp1 : tmp0;
            TORCH_CHECK((0 <= tmp3) & (tmp3 < 10000L), "index out of bounds: 0 <= tmp3 < 10000L")
            auto tmp4 = out_ptr0[static_cast<long>(tmp3)];
            auto tmp5 = static_cast<float>(-0.5);
            auto tmp6 = std::pow(tmp4, tmp5);
            auto tmp7 = std::numeric_limits<float>::infinity();
            auto tmp8 = tmp6 == tmp7;
            auto tmp9 = static_cast<float>(0.0);
            auto tmp10 = tmp8 ? tmp9 : tmp6;
            auto tmp12 = decltype(tmp11)(tmp11 + 10000);
            auto tmp13 = tmp11 < 0;
            auto tmp14 = tmp13 ? tmp12 : tmp11;
            TORCH_CHECK((0 <= tmp14) & (tmp14 < 10000L), "index out of bounds: 0 <= tmp14 < 10000L")
            auto tmp15 = out_ptr0[static_cast<long>(tmp14)];
            auto tmp16 = std::pow(tmp15, tmp5);
            auto tmp17 = tmp16 == tmp7;
            auto tmp18 = tmp17 ? tmp9 : tmp16;
            auto tmp19 = decltype(tmp10)(tmp10 * tmp18);
            out_ptr1[static_cast<long>(x0)] = tmp19;
        }
    }
}
''')
# Wait for the kernel compilation(s) submitted through async_compile above.
# The module's globals() are passed in, presumably so pending results bound
# to module-level names can be resolved in place -- confirm against
# AsyncCompile.wait. After that the handle is no longer needed and is dropped.
async_compile.wait(globals())
del async_compile
def call(args):
    """Run the compiled graph on its single input.

    ``args`` must be a one-element list holding a tensor of size (2, 209985)
    and stride (209985, 1). The fused kernel reads it as ``const int64_t*``.
    The list is emptied in place before the kernel runs.

    Returns a 1-tuple with a new float32 CPU tensor of length 209985 that the
    fused kernel fills. A float32 scratch buffer of length 10000 holds the
    counts the kernel accumulates on the way.
    """
    (indices,) = args
    args.clear()
    assert_size_stride(indices, (2, 209985), (209985, 1))
    # Scratch: the kernel zero-fills this and scatter-adds 1.0 into it.
    counts = empty_strided_cpu((10000, ), (1, ), torch.float32)
    result = empty_strided_cpu((209985, ), (1, ), torch.float32)
    cpp_fused_eq_index_masked_fill_mul_new_zeros_ones_pow_scatter_add_0(indices, counts, result)
    del indices
    return (result, )
def benchmark_compiled_module(times=10, repeat=10):
    """Time ``call`` on a synthetic (2, 209985) int64 CPU input.

    ``times`` and ``repeat`` are forwarded unchanged to ``print_performance``.
    Its return value is returned as-is.
    """
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance

    # NOTE(review): rand_strided appears to zero-fill non-floating dtypes, so
    # every index in this input may be 0 -- confirm before treating the timings
    # as representative of real inputs.
    sample = rand_strided((2, 209985), (209985, 1), device='cpu', dtype=torch.int64)

    def run_once():
        # A fresh list each run, because call() empties the list it receives.
        return call([sample])

    return print_performance(run_once, times=times, repeat=repeat)
if __name__ == "__main__":
    # When executed directly, hand the benchmark function to inductor's
    # generic compiled-module driver under the label 'basic_gnn_gcn'.
    from torch._inductor import wrapper_benchmark
    wrapper_benchmark.compiled_module_main('basic_gnn_gcn', benchmark_compiled_module)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment