Sorted COO of SpMM in Triton
@fishmingyu · Created November 27, 2023
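The script below implements sparse-dense matrix multiplication (SpMM) over a COO edge list in two ways: spmm_atomic issues one atomic add per (edge, feature) element, while spmm_sorted_coo_naive assumes the edge list is sorted by destination node, accumulating consecutive edges with the same destination in registers and flushing each run with a single atomic add.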
import time

import torch
import triton
import triton.language as tl


@triton.jit
def spmm_atomic(edge_index, B, C, num_edges,
                feature_size: tl.constexpr, XBLOCK: tl.constexpr):
    # Each program covers XBLOCK (edge, feature) elements and scatters
    # every element into C with its own atomic add.
    group_id = tl.program_id(0)
    xoffset = group_id * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)
    x1 = xindex // feature_size  # edge id
    x2 = xindex % feature_size   # feature id within the row
    mask = x1 < num_edges
    in_node = tl.load(edge_index + x1, mask=mask)               # source node
    out_node = tl.load(edge_index + x1 + num_edges, mask=mask)  # destination node
    in_val = tl.load(B + in_node * feature_size + x2, mask=mask)
    tl.atomic_add(C + out_node * feature_size + x2, in_val, mask=mask)
@triton.jit
def spmm_sorted_coo_naive(edge_index, B, C, num_edges,
                          feature_size: tl.constexpr, group_size: tl.constexpr):
    # Each program handles group_size consecutive edges. Because the edge
    # list is sorted by destination node, runs of edges sharing a destination
    # can be accumulated in registers and flushed with one atomic add.
    group_id = tl.program_id(0)
    node_offset = group_id * group_size
    f_index = tl.arange(0, feature_size)

    xn = node_offset
    mask = xn < num_edges
    in_node = tl.load(edge_index + xn, mask=mask)  # Load the input node
    out_node = tl.load(edge_index + xn + num_edges,
                       mask=mask)  # Load the output node
    curr_node = out_node
    val = tl.load(B + in_node * feature_size + f_index, mask=mask)
    for ii in range(1, group_size):  # Iterate over the rest of the group
        xn = ii + node_offset  # Get the edge index
        mask = xn < num_edges  # Check if the edge index is valid
        in_node = tl.load(edge_index + xn, mask=mask)  # Load the input node
        out_node = tl.load(edge_index + xn + num_edges,
                           mask=mask)  # Load the output node
        new_val = tl.load(B + in_node * feature_size + f_index, mask=mask)
        if out_node != curr_node:
            # Destination changed: flush the accumulated row atomically.
            tl.atomic_add(C + curr_node * feature_size + f_index,
                          val, mask=mask)
            # Reset val for the new row
            val = new_val
            curr_node = out_node
        else:
            # Accumulate val in registers
            val += new_val
    # Flush the last accumulated row. The flush reuses the most recent mask,
    # so a group is assumed not to straddle the end of the edge list
    # (num_edges is a multiple of group_size in the driver below).
    tl.atomic_add(C + out_node * feature_size + f_index, val, mask=mask)
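

# Worked micro-example (hypothetical data, not part of the original gist):
# with group_size = 3 and destinations sorted as [0, 0, 1], the loop above
# accumulates the first two edges' feature rows in registers and issues only
# two atomic adds in total (one flush for row 0, one for row 1), whereas
# spmm_atomic would issue one atomic add per edge, i.e. three.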
num_nodes, num_edges = 10_000, 200_000
features = 64
edge_index = torch.randint(num_nodes, (2, num_edges), device="cuda")

# Sort the edges by destination node (the second row) so that edges writing
# to the same output row become contiguous, keeping each
# (source, destination) pair intact.
_, perm = edge_index[1].sort()
sorted_edge_index = edge_index[:, perm].contiguous()

B = torch.rand(num_nodes, features, device="cuda", dtype=torch.float32)
C_atomic = torch.zeros(num_nodes, features, device="cuda", dtype=torch.float32)
C_sorted = torch.zeros(num_nodes, features, device="cuda", dtype=torch.float32)

group_size = 50
grid_atomic = (triton.cdiv(num_edges * features, 128),)
spmm_atomic[grid_atomic](sorted_edge_index, B, C_atomic,
                         num_edges, features, 128)
grid_sorted = (triton.cdiv(num_edges, group_size),)
spmm_sorted_coo_naive[grid_sorted](
    sorted_edge_index, B, C_sorted, num_edges, features, group_size)

# Compare the results of the two kernels.
print(torch.allclose(C_atomic, C_sorted))
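
# Extra sanity check (an addition, not in the original gist): the same SpMM
# can be expressed with torch.index_add_, which adds the gathered source
# rows of B into C at the destination indices.
C_ref = torch.zeros(num_nodes, features, device="cuda", dtype=torch.float32)
C_ref.index_add_(0, sorted_edge_index[1], B[sorted_edge_index[0]])
print(torch.allclose(C_atomic, C_ref))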
# Benchmark: 100 launches of each kernel. The outputs keep accumulating
# across iterations, which is fine here since only launch time is measured.
torch.cuda.synchronize()
start = time.time()
for i in range(100):
    spmm_atomic[grid_atomic](sorted_edge_index, B, C_atomic,
                             num_edges, features, 128)
torch.cuda.synchronize()
end = time.time()
print("atomic time: ", end - start)

torch.cuda.synchronize()
start = time.time()
for i in range(100):
    spmm_sorted_coo_naive[grid_sorted](
        sorted_edge_index, B, C_sorted, num_edges, features, group_size)
torch.cuda.synchronize()
end = time.time()
print("sorted time: ", end - start)