# Scraped from a GitHub Gist page (page navigation/UI text removed).
# Author: @zhuhaozhe, last active April 17, 2024 06:09
# Gist: zhuhaozhe/408469f904e644fa2ea6c3c5ac013453
# File: scalar.py
from ctypes import c_void_p, c_long
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align
from torch import device, empty_strided
from torch._inductor.codecache import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
from torch._inductor.codegen.multi_kernel import MultiKernelCall
aten = torch.ops.aten
inductor_ops = torch.ops.inductor
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
alloc_from_pool = torch.ops.inductor._alloc_from_pool
reinterpret_tensor = torch.ops.inductor._reinterpret_tensor
async_compile = AsyncCompile()
# Fused C++ CPU kernel generated by TorchInductor and compiled through
# AsyncCompile.cpp_pybinding.
#
# Arguments (from the binding signature): a const int64 index buffer holding
# two rows of 209985 entries (row 0 at offset 0, row 1 at offset 209985) and
# two float32 output buffers.
#   1. Zero-fills out_ptr0[0:10000) with vectorized stores, 16 lanes per step.
#      NOTE(review): the stride of 16 assumes Vectorized<float> has 16 lanes
#      (AVX-512 build); confirm against the ISA this was generated for.
#   2. For every column, wraps a negative row-1 index by +10000, bounds-checks
#      it, and adds 1.0 to out_ptr0 at that index (an occurrence count).
#   3. For every column, computes count[row0] ** -0.5 * count[row1] ** -0.5,
#      replacing +inf factors (e.g. from a zero count) by 0.0, and stores the
#      product into out_ptr1[x0].
#
# Fix: the scatter-add loop (step 2) no longer carries `#pragma GCC ivdep`.
# That pragma asserts there are no loop-carried dependencies. However,
# `out_ptr0[idx] += 1` reads and writes the same element whenever two
# iterations share an index. A SIMD gather/scatter of that loop would
# therefore drop updates. The gather loop (step 3) writes only out_ptr1[x0],
# so it keeps its pragma.
#
# NOTE(review): the #include path points into a per-user Inductor cache under
# /tmp; compilation requires that header to exist on the running machine.
cpp_fused_eq_index_masked_fill_mul_new_zeros_ones_pow_scatter_add_0 = async_compile.cpp_pybinding(['const int64_t*', 'float*', 'float*'], '''
#include "/tmp/torchinductor_haozhe/kn/cknhgwrjj77u7ocflnojvr6pbo5cbrerqim3blda2rcpzbh4gnow.h"
extern "C" void kernel(const int64_t* in_ptr0,
                       float* out_ptr0,
                       float* out_ptr1)
{
    {
        for(long x0=static_cast<long>(0L); x0<static_cast<long>(10000L); x0+=static_cast<long>(16L))
        {
            auto tmp0 = static_cast<float>(0.0);
            auto tmp1 = at::vec::Vectorized<float>(tmp0);
            tmp1.store(out_ptr0 + static_cast<long>(x0));
        }
    }
    {
        // No `#pragma GCC ivdep` here: different iterations may update the
        // same out_ptr0 element, which is a loop-carried dependency.
        for(long x0=static_cast<long>(0L); x0<static_cast<long>(209985L); x0+=static_cast<long>(1L))
        {
            auto tmp0 = in_ptr0[static_cast<long>(209985L + x0)];
            auto tmp1 = decltype(tmp0)(tmp0 + 10000);
            auto tmp2 = tmp0 < 0;
            auto tmp3 = tmp2 ? tmp1 : tmp0;
            TORCH_CHECK((0 <= tmp3) & (tmp3 < 10000L), "index out of bounds: 0 <= tmp3 < 10000L")
            auto tmp4 = static_cast<float>(1.0);
            out_ptr0[static_cast<long>(tmp3)] += tmp4;
        }
    }
    {
        #pragma GCC ivdep
        for(long x0=static_cast<long>(0L); x0<static_cast<long>(209985L); x0+=static_cast<long>(1L))
        {
            auto tmp0 = in_ptr0[static_cast<long>(x0)];
            auto tmp11 = in_ptr0[static_cast<long>(209985L + x0)];
            auto tmp1 = decltype(tmp0)(tmp0 + 10000);
            auto tmp2 = tmp0 < 0;
            auto tmp3 = tmp2 ? tmp1 : tmp0;
            TORCH_CHECK((0 <= tmp3) & (tmp3 < 10000L), "index out of bounds: 0 <= tmp3 < 10000L")
            auto tmp4 = out_ptr0[static_cast<long>(tmp3)];
            auto tmp5 = static_cast<float>(-0.5);
            auto tmp6 = std::pow(tmp4, tmp5);
            auto tmp7 = std::numeric_limits<float>::infinity();
            auto tmp8 = tmp6 == tmp7;
            auto tmp9 = static_cast<float>(0.0);
            auto tmp10 = tmp8 ? tmp9 : tmp6;
            auto tmp12 = decltype(tmp11)(tmp11 + 10000);
            auto tmp13 = tmp11 < 0;
            auto tmp14 = tmp13 ? tmp12 : tmp11;
            TORCH_CHECK((0 <= tmp14) & (tmp14 < 10000L), "index out of bounds: 0 <= tmp14 < 10000L")
            auto tmp15 = out_ptr0[static_cast<long>(tmp14)];
            auto tmp16 = std::pow(tmp15, tmp5);
            auto tmp17 = tmp16 == tmp7;
            auto tmp18 = tmp17 ? tmp9 : tmp16;
            auto tmp19 = decltype(tmp10)(tmp10 * tmp18);
            out_ptr1[static_cast<long>(x0)] = tmp19;
        }
    }
}
''')
# Block until the kernel above has finished compiling, then drop the compiler
# handle; nothing later in the module uses it.
async_compile.wait(globals())
del async_compile
def call(args):
    """Run the fused kernel on a single index tensor of shape (2, 209985).

    ``args`` must be a one-element list holding that tensor. The list is
    cleared in place, which drops its reference to the input. Returns a
    1-tuple with the kernel's 209985-element float32 output (``out_ptr1``).
    """
    (index,) = args
    args.clear()
    assert_size_stride(index, (2, 209985), (209985, 1))
    # 10000-element scratch buffer that the kernel fills with per-slot counts;
    # only its second output is handed back to the caller.
    counts = empty_strided_cpu((10000,), (1,), torch.float32)
    result = empty_strided_cpu((209985,), (1,), torch.float32)
    cpp_fused_eq_index_masked_fill_mul_new_zeros_ones_pow_scatter_add_0(index, counts, result)
    del index
    return (result,)
def benchmark_compiled_module(times=10, repeat=10):
    """Time ``call`` on a synthetic CPU input and report it via print_performance.

    The input is an int64 tensor of size (2, 209985) and stride (209985, 1)
    produced by ``rand_strided``. ``times`` and ``repeat`` are forwarded
    unchanged, and whatever ``print_performance`` returns is returned.

    NOTE(review): to my knowledge ``rand_strided`` fills non-floating dtypes
    with zeros, which would make every index here 0. Confirm this before
    treating the timing as representative of real index data.
    """
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance

    sample = rand_strided((2, 209985), (209985, 1), device='cpu', dtype=torch.int64)

    def run_once():
        # A fresh list per invocation, because call() clears its argument list.
        return call([sample])

    return print_performance(run_once, times=times, repeat=repeat)
if __name__ == "__main__":
    # Script entry only: pass the benchmark function to Inductor's
    # compiled_module_main helper under the benchmark name 'basic_gnn_gcn'.
    from torch._inductor.wrapper_benchmark import compiled_module_main

    compiled_module_main("basic_gnn_gcn", benchmark_compiled_module)
# GitHub page footer: "Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment"