@simveit
Created May 19, 2025 20:14

import torch

import cutlass
import cutlass.cute as cute
import cutlass.cute.runtime as cute_rt
from cutlass.torch import dtype as torch_dtype


@cute.kernel
def transpose_naive_kernel(tiled_tensor_src, tiled_tensor_dst, ThreadLayout):
    tidx, _, _ = cute.arch.thread_idx()
    bidx, bidy, _ = cute.arch.block_idx()

    # Select this thread block's tile from the tiled source and destination views.
    block_coordinate = ((None, None), bidx, bidy)
    global_tile_src = tiled_tensor_src[block_coordinate]
    global_tile_dst_transposed = tiled_tensor_dst[block_coordinate]

    # Partition the source tile across the threads of the block.
    thread_global_tile_src = cute.local_partition(
        global_tile_src, tiler=ThreadLayout, index=tidx
    )

    if tidx == 0 and bidx == 0 and bidy == 0:
        cute.printf(" >> block_coordinate = {}", block_coordinate)
        cute.printf(" >> global_tile_src layout = {}", global_tile_src.layout)
        cute.printf(
            " >> global_tile_dst_transposed layout = {}",
            global_tile_dst_transposed.layout,
        )
        cute.printf(" >> ThreadLayout = {}", ThreadLayout)


@cute.jit
def transpose_naive(
    M: cutlass.Constexpr,
    N: cutlass.Constexpr,
    ptr_a: cute.Pointer,
    ptr_b: cute.Pointer,
    coalesced_read: bool,
):
    # Create GMEM layouts
    tensor_shape = (M, N)
    tensor_shape_transposed = (N, M)
    global_memory_layout_src = cute.make_layout(
        tensor_shape, stride=(N, 1)
    )  # (M, N) : (N, 1)
    global_memory_layout_dst = cute.make_layout(
        tensor_shape_transposed, stride=(M, 1)
    )  # (N, M) : (M, 1)
    global_memory_layout_dst_transposed = cute.make_layout(
        tensor_shape, stride=(1, M)
    )  # (M, N) : (1, M), column-major view of the destination buffer
    # Create tensors
    tensor_src = cute.make_tensor(ptr_a, global_memory_layout_src)
    tensor_dst = cute.make_tensor(ptr_b, global_memory_layout_dst)
    tensor_dst_transposed = cute.make_tensor(ptr_b, global_memory_layout_dst_transposed)
    # Block Tiling
    TileSizeX = 64  # bN
    TileSizeY = 32  # bM
    block_shape = (TileSizeY, TileSizeX)
    tiled_tensor_src = cute.tiled_divide(
        tensor_src, block_shape
    )  # ((TileSizeY, TileSizeX), M/TileSizeY, N/TileSizeX)
    tiled_tensor_dst_transposed = cute.tiled_divide(
        tensor_dst_transposed, block_shape
    )  # ((TileSizeY, TileSizeX), M/TileSizeY, N/TileSizeX)
    # Thread Tiling
    ThreadBlockSizeX = 32  # tN
    ThreadBlockSizeY = 8  # tM
    thread_block_shape = (ThreadBlockSizeY, ThreadBlockSizeX)
    thread_block_shape_transposed = (ThreadBlockSizeX, ThreadBlockSizeY)
    # Row-major thread layout: consecutive threads step along N, so reads from the
    # row-major source are coalesced while writes to the destination are strided.
    thread_layout = cute.make_layout(
        thread_block_shape, stride=(ThreadBlockSizeX, 1)
    )  # (tM, tN) : (tN, 1)
    # Column-major thread layout: consecutive threads step along M, so writes to the
    # destination are coalesced while reads from the source are strided.
    thread_layout_transposed = cute.make_layout(
        thread_block_shape_transposed, stride=(1, ThreadBlockSizeX)
    )  # (tN, tM) : (1, tN)
    # Print the layouts for inspection
    cute.printf(
        """Overview:
>>> M = {}, N = {}, bM = {}, bN = {}
>>> global_memory_layout_src = {}
>>> global_memory_layout_dst = {}
>>> global_memory_layout_dst_transposed = {}
>>> tensor_src layout = {}
>>> tensor_dst layout = {}
>>> tensor_dst_transposed layout = {}
>>> tiled_tensor_src layout = {}
>>> tiled_tensor_dst_transposed layout = {}
>>> thread_layout = {}
>>> thread_layout_transposed = {}""",
        M,
        N,
        TileSizeY,
        TileSizeX,
        global_memory_layout_src,
        global_memory_layout_dst,
        global_memory_layout_dst_transposed,
        tensor_src.layout,
        tensor_dst.layout,
        tensor_dst_transposed.layout,
        tiled_tensor_src.layout,
        tiled_tensor_dst_transposed.layout,
        thread_layout,
        thread_layout_transposed,
    )
    if coalesced_read:
        # Row-major thread layout: coalesced reads from the source, strided writes.
        transpose_naive_kernel(
            tiled_tensor_src, tiled_tensor_dst_transposed, thread_layout
        ).launch(
            grid=[
                cute.size(tiled_tensor_dst_transposed, mode=[2]),
                cute.size(tiled_tensor_dst_transposed, mode=[1]),
                1,
            ],
            block=[ThreadBlockSizeX * ThreadBlockSizeY, 1, 1],
        )
    else:
        # Column-major thread layout: coalesced writes to the destination, strided reads.
        transpose_naive_kernel(
            tiled_tensor_src, tiled_tensor_dst_transposed, thread_layout_transposed
        ).launch(
            grid=[
                cute.size(tiled_tensor_dst_transposed, mode=[2]),
                cute.size(tiled_tensor_dst_transposed, mode=[1]),
                1,
            ],
            block=[ThreadBlockSizeX * ThreadBlockSizeY, 1, 1],
        )


M, N = 2048, 4096

# Source matrix and a same-sized destination buffer on the GPU; the kernel
# reinterprets the destination buffer as an (N, M) row-major matrix.
a = torch.randn(M, N, dtype=torch_dtype(cutlass.Float32), device="cuda")
ptr_a = cute_rt.make_ptr(cutlass.Float32, a.data_ptr())
b = torch.randn(M, N, dtype=torch_dtype(cutlass.Float32), device="cuda")
ptr_b = cute_rt.make_ptr(cutlass.Float32, b.data_ptr())

print("Launch with coalesced read")
transpose_naive(M, N, ptr_a, ptr_b, True)
print("Launch with coalesced write")
transpose_naive(M, N, ptr_a, ptr_b, False)
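
# Hedged check (not in the original gist): if the kernel includes the copy step
# sketched above, the destination buffer holds A^T as an (N, M) row-major matrix,
# so it can be compared against torch's own transpose.
torch.cuda.synchronize()
print("Transpose correct:", torch.allclose(b.flatten().view(N, M), a.t()))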