Skip to content

Instantly share code, notes, and snippets.

@fishmingyu
Last active December 8, 2023 10:31
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save fishmingyu/60910e86411c60884bdeec3878dfeda3 to your computer and use it in GitHub Desktop.
Save fishmingyu/60910e86411c60884bdeec3878dfeda3 to your computer and use it in GitHub Desktop.
Gather Scatter fusion in PyG by Inductor
import torch
import torch_geometric
from torch_geometric.profile import benchmark
from torch_geometric.testing import (
disableExtensions,
onlyFullTest,
onlyLinux,
withCUDA,
withPackage,
)
from torch_geometric.utils import scatter
# Basic "Gather-Apply-Scatter" pattern commonly used in PyG (gather features along edges, then scatter-reduce them onto nodes):
def gather_scatter(x, edge_index, reduce="sum"):
    """Gather features along edges, then scatter-reduce them onto nodes.

    Args:
        x: Node feature tensor. ``x.size(0)`` is used as ``dim_size`` for the
            scatter, i.e. the number of output slots.
        edge_index: Something that unpacks into exactly two index tensors
            ``(src, dst)``, e.g. a ``(2, num_edges)`` tensor. Rows ``x[src]``
            are gathered for every edge and reduced into positions ``dst``.
        reduce: Reduction name forwarded to ``scatter`` (default ``"sum"``).

    Returns:
        The result of ``scatter`` with ``dim_size=x.size(0)`` (presumably one
        row per node, given ``scatter``'s default ``dim=0`` — confirm).
    """
    src, dst = edge_index
    num_nodes = x.size(0)
    messages = x[src]  # one gathered feature row per edge
    return scatter(messages, dst, dim_size=num_nodes, reduce=reduce)
@withCUDA
@onlyLinux
@onlyFullTest
@disableExtensions
@withPackage("torch>=2.0.0")
def test_torch_compile(device):
    """Check that the compiled ``gather_scatter`` matches eager execution.

    ``device`` is injected by ``@withCUDA``, PyG's test decorator that
    parametrizes the test over the available devices.
    """
    x = torch.randn(10, 16, device=device)
    edge_index = torch.randint(0, x.size(0), (2, 40), device=device)

    expected = gather_scatter(x, edge_index)
    compiled_op = torch_geometric.compile(gather_scatter)
    out = compiled_op(x, edge_index)
    # The compiled kernel may accumulate the per-node sums in a different
    # order than eager mode, so allow a small absolute tolerance.
    assert torch.allclose(out, expected, atol=1e-6)
if __name__ == "__main__":
    import argparse

    # Standalone driver: compile `gather_scatter` with the Inductor backend on
    # a larger random graph and run it once (e.g. to inspect the generated
    # kernel). With --backward, a backward pass is run through it as well.
    parser = argparse.ArgumentParser()
    parser.add_argument("--device", type=str, default="cpu",
                        help="device to run on, e.g. 'cpu' or 'cuda'")
    parser.add_argument("--backward", action="store_true",
                        help="also run a backward pass through the compiled op")
    args = parser.parse_args()

    num_nodes, num_edges = 10_000, 200_000
    feature_size = 32

    # Track gradients only when a backward pass was requested, so the default
    # run stays forward-only.
    x = torch.randn(num_nodes, feature_size, device=args.device,
                    requires_grad=args.backward)
    edge_index = torch.randint(num_nodes, (2, num_edges), device=args.device)

    compiled_func = torch_geometric.compile(gather_scatter, backend="inductor")
    out = compiled_func(x, edge_index)
    if args.backward:
        out.backward(torch.ones_like(out))
from ctypes import c_void_p, c_long
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch import empty_strided, device
from torch._inductor.codecache import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
# Short aliases used by Inductor-generated wrapper code (not every alias is
# referenced in this particular module).
aten = torch.ops.aten
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
reinterpret_tensor = torch.ops.inductor._reinterpret_tensor
# Handle used to compile the C++ kernel(s) below; `wait` further down blocks
# until they are ready.
async_compile = AsyncCompile()
# Fused C++ kernel, run with OpenMP (8 threads). Arguments as passed by `call`:
#   in_ptr0  -- int64 edge index, flattened from (2, 200000): entry i0 of row 0
#               is at offset i0, entry i0 of row 1 at offset 200000 + i0
#   in_ptr1  -- float32 node features, flattened from (10000, 32)
#   out_ptr0 -- float32 output buffer, flattened from (10000, 32)
# Pass 1 zero-fills all 320000 output elements in vectorized steps of 8.
# Pass 2 runs over every edge i0 and feature column i1: it reads
# in_ptr1[32*src + i1], where src = in_ptr0[i0], and atomically adds that
# value to out_ptr0[32*dst + i1], where dst = in_ptr0[200000 + i0].
# The add must be atomic because edges are split across threads and several
# edges can share the same dst.
# NOTE(review): the #include path points into the generating machine's /tmp
# Inductor cache, so this module is presumably not portable as-is.
cpp_fused_index_new_zeros_scatter_add_0 = async_compile.cpp('''
#include "/tmp/torchinductor_hanxian/ib/cibrnuq56cxamjj4krp4zpjvsirbmlolpbnmomodzyd46huzhdw7.h"
extern "C" void kernel(const long* in_ptr0,
const float* in_ptr1,
float* out_ptr0)
{
#pragma omp parallel num_threads(8)
{
{
#pragma omp for
for(long i0=static_cast<long>(0L); i0<static_cast<long>(320000L); i0+=static_cast<long>(8L))
{
auto tmp0 = at::vec::Vectorized<float>(static_cast<float>(0.0));
tmp0.store(out_ptr0 + static_cast<long>(i0));
}
}
{
#pragma omp for
for(long i0=static_cast<long>(0L); i0<static_cast<long>(200000L); i0+=static_cast<long>(1L))
{
#pragma GCC ivdep
for(long i1=static_cast<long>(0L); i1<static_cast<long>(32L); i1+=static_cast<long>(1L))
{
auto tmp0 = in_ptr0[static_cast<long>(200000L + i0)];
auto tmp1 = in_ptr0[static_cast<long>(i0)];
auto tmp2 = in_ptr1[static_cast<long>(i1 + (32L*tmp1))];
atomic_add(&out_ptr0[static_cast<long>(i1 + (32L*tmp0))], tmp2);
}
}
}
}
}
''')
# Block until pending kernel compilation finishes. globals() is passed
# presumably so the kernel name assigned above can be resolved in place —
# confirm against AsyncCompile.
async_compile.wait(globals())
# The compile handle is no longer needed once the kernels are ready.
del async_compile
def call(args):
    """Run the fused gather/scatter kernel on ``[features, edge_index]``.

    ``args`` must hold exactly two tensors: the float32 (10000, 32) features
    and the (2, 200000) edge index. ``assert_size_stride`` is called on each
    input with the exact sizes and strides this kernel was generated for.
    The list is emptied in place, presumably so the inputs can be released
    as soon as this function drops its own references.

    Returns:
        A 1-tuple holding the freshly allocated (10000, 32) float32 output.
    """
    feats, edges = args
    args.clear()
    assert_size_stride(feats, (10000, 32), (32, 1))
    assert_size_stride(edges, (2, 200000), (200000, 1))
    out = empty_strided((10000, 32), (32, 1), device='cpu', dtype=torch.float32)
    # Kernel argument order: edge index, node features, output buffer.
    raw_ptrs = [c_void_p(t.data_ptr()) for t in (edges, feats, out)]
    cpp_fused_index_new_zeros_scatter_add_0(*raw_ptrs)
    del feats, edges
    return (out, )
def benchmark_compiled_module(times=10, repeat=10):
    """Time ``call`` on random inputs with the layouts it was compiled for.

    ``times`` and ``repeat`` are forwarded to Inductor's ``print_performance``,
    whose result is returned.
    """
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance

    # (size, stride, dtype) per input, in the order `call` unpacks them.
    specs = (
        ((10000, 32), (32, 1), torch.float32),
        ((2, 200000), (200000, 1), torch.int64),
    )
    inputs = [rand_strided(size, stride, device='cpu', dtype=dtype)
              for size, stride, dtype in specs]

    def run_once():
        # `call` clears the list it receives, so hand it a fresh copy each time.
        return call(list(inputs))

    return print_performance(run_once, times=times, repeat=repeat)
if __name__ == "__main__":
    # Entry point when this generated module is executed directly: pass the
    # benchmark callback to Inductor's compiled_module_main.
    from torch._inductor import wrapper_benchmark
    wrapper_benchmark.compiled_module_main('None', benchmark_compiled_module)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment