Created
November 27, 2023 02:00
-
-
Save fishmingyu/36d7941d64b5d43de928ec5dc6012e6a to your computer and use it in GitHub Desktop.
SpMM over sorted COO edges in Triton
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import time | |
import triton | |
import triton.language as tl | |
import torch | |
@triton.jit
def spmm_atomic(edge_index, B, C, num_edges, feature_size: tl.constexpr, XBLOCK: tl.constexpr):
    """Scatter-add one element of B into C per (edge, feature) pair.

    Each program instance covers XBLOCK consecutive flat indices. A flat
    index ``i`` maps to edge ``i // feature_size`` and feature column
    ``i % feature_size``.

    Args:
        edge_index: pointer read as two back-to-back arrays of ``num_edges``
            entries each: input-node ids at offsets ``[0, num_edges)`` and
            output-node ids at ``[num_edges, 2 * num_edges)``.
        B: pointer to the input features, addressed as
            ``node * feature_size + column``.
        C: pointer to the output, addressed the same way; updated with
            atomic adds.
        num_edges: number of edges.
        feature_size: number of feature columns per node (compile-time).
        XBLOCK: flat indices handled by one program instance (compile-time).
    """
    pid = tl.program_id(0)
    flat = pid * XBLOCK + tl.arange(0, XBLOCK)

    # Split the flat index into (edge, feature column).
    edge = flat // feature_size
    col = flat % feature_size

    # Lanes past the last edge are masked out of every memory access.
    valid = edge < num_edges

    src = tl.load(edge_index + edge, valid)
    dst = tl.load(edge_index + edge + num_edges, valid)

    contribution = tl.load(B + src * feature_size + col, valid)
    tl.atomic_add(C + dst * feature_size + col, contribution, valid)
@triton.jit
def spmm_sorted_coo_naive(edge_index, B, C, num_edges, feature_size: tl.constexpr, group_size: tl.constexpr):
    """Accumulate B rows over a run of edges and flush once per output row.

    Each program instance walks ``group_size`` consecutive edges. While the
    output node stays the same, the gathered B rows are summed locally; when
    it changes (and once more at the end of the group) the running sum is
    atomically added to the corresponding row of C. Results are correct for
    any edge order because every flush is an atomic add, but the number of
    atomics only shrinks when consecutive edges share an output node.

    Args:
        edge_index: pointer read as two back-to-back arrays of ``num_edges``
            entries each: input-node ids at offsets ``[0, num_edges)`` and
            output-node ids at ``[num_edges, 2 * num_edges)``.
        B: pointer to the input features, addressed as
            ``node * feature_size + column``.
        C: pointer to the output, addressed the same way.
        num_edges: number of edges; need not be a multiple of ``group_size``.
        feature_size: number of feature columns (compile-time). Used as the
            extent of ``tl.arange``, which requires a power of two.
        group_size: edges processed by one program instance (compile-time).
    """
    group_id = tl.program_id(0)
    node_offset = group_id * group_size
    f_index = tl.arange(0, feature_size)

    # Validity of the group's first edge. Every flush is guarded by this
    # mask rather than by the mask of whichever edge triggered the flush,
    # so a partially filled last group does not lose its accumulated sum.
    group_valid = node_offset < num_edges

    in_node = tl.load(edge_index + node_offset, mask=group_valid, other=0)
    curr_node = tl.load(edge_index + node_offset + num_edges,
                        mask=group_valid, other=0)
    val = tl.load(B + in_node * feature_size + f_index,
                  mask=group_valid, other=0.0)

    for ii in range(1, group_size):  # Remaining edges of the group
        xn = ii + node_offset
        mask = xn < num_edges  # False for slots past the last edge
        in_node = tl.load(edge_index + xn, mask=mask, other=0)
        out_node = tl.load(edge_index + xn + num_edges, mask=mask, other=0)
        # Out-of-range slots contribute zeros and never count as a row change.
        new_val = tl.load(B + in_node * feature_size + f_index,
                          mask=mask, other=0.0)
        row_changed = mask & (out_node != curr_node)
        if row_changed:
            # Flush the finished row, then start accumulating the new one.
            tl.atomic_add(C + curr_node * feature_size + f_index,
                          val, mask=group_valid)
            val = new_val
            curr_node = out_node
        else:
            val += new_val

    # Flush whatever is still pending for the last row seen in this group.
    tl.atomic_add(C + curr_node * feature_size + f_index, val, mask=group_valid)
# Problem size for the comparison / benchmark below.
num_nodes, num_edges = 10_000, 200_000
features = 64
XBLOCK = 128  # flat (edge, feature) indices per program in spmm_atomic
group_size = 50  # edges per program in spmm_sorted_coo_naive

edge_index = torch.randint(num_nodes, (2, num_edges), device="cuda")

# Sort whole edges by their output node (row 1). Sorting the two columns of
# the transposed tensor independently would break the (input, output)
# pairing, so the permutation from sorting row 1 is applied to both rows.
_, order = edge_index[1].sort()
# The kernels index the raw pointer as a contiguous (2, num_edges) array,
# so make the layout explicit instead of passing a strided view.
sorted_edge_index = edge_index[:, order].contiguous()

B = torch.rand(num_nodes, features, device="cuda", dtype=torch.float32)
C_atomic = torch.zeros(num_nodes, features, device="cuda", dtype=torch.float32)
C_sorted = torch.zeros(num_nodes, features, device="cuda", dtype=torch.float32)

grid_atomic = (triton.cdiv(num_edges * features, XBLOCK), )
spmm_atomic[grid_atomic](sorted_edge_index, B,
                         C_atomic, num_edges, features, XBLOCK)

grid_sorted = (triton.cdiv(num_edges, group_size), )
spmm_sorted_coo_naive[grid_sorted](
    sorted_edge_index, B, C_sorted, num_edges, features, group_size)

# compare the result
print(torch.allclose(C_atomic, C_sorted))

# test performance
# NOTE: the timed launches keep accumulating into C_atomic / C_sorted; only
# the elapsed time is meaningful after this point, not the buffer contents.
torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(100):
    spmm_atomic[grid_atomic](sorted_edge_index, B,
                             C_atomic, num_edges, features, XBLOCK)
torch.cuda.synchronize()
end = time.perf_counter()
print("atomic time: ", end - start)

torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(100):
    spmm_sorted_coo_naive[grid_sorted](
        sorted_edge_index, B, C_sorted, num_edges, features, group_size)
torch.cuda.synchronize()
end = time.perf_counter()
print("sorted time: ", end - start)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment