import cutlass
import cutlass.cute as cute
import torch
from cutlass.torch import dtype as torch_dtype
import cutlass.cute.runtime as cute_rt

@cute.kernel
def transpose_naive_kernel(
    tiled_tensor_src: cute.Tensor,
    tiled_tensor_dst: cute.Tensor,
    ThreadLayout: cute.Layout,
):
    tidx, _, _ = cute.arch.thread_idx()
    bidx, bidy, _ = cute.arch.block_idx()

    # Keep the inner (TileSizeY, TileSizeX) mode whole and pick this CTA's
    # tile out of the outer tile grid via the block indices.
    block_coordinate = ((None, None), bidx, bidy)
    global_tile_src = tiled_tensor_src[block_coordinate]
    global_tile_dst_transposed = tiled_tensor_dst[block_coordinate]

    # Distribute the tile's elements over the CTA's threads.
    thread_global_tile_src = cute.local_partition(
        global_tile_src, tiler=ThreadLayout, index=tidx
    )

    if tidx == 0 and bidx == 0 and bidy == 0:
        cute.printf(" >> block_coordinate = {}", block_coordinate)
        cute.printf(" >> global_tile_src layout = {}", global_tile_src.layout)
        cute.printf(
            " >> global_tile_dst_transposed layout = {}",
            global_tile_dst_transposed.layout,
        )
        cute.printf(" >> ThreadLayout = {}", ThreadLayout)

@cute.jit
def transpose_naive(
    M: cutlass.Constexpr,
    N: cutlass.Constexpr,
    ptr_a: cute.Pointer,
    ptr_b: cute.Pointer,
    coalesced_read: bool,
):
    # Create GMEM layouts
    tensor_shape = (M, N)
    tensor_shape_transposed = (N, M)
    global_memory_layout_src = cute.make_layout(
        tensor_shape, stride=(N, 1)
    )  # (M, N) : (N, 1), row-major source
    global_memory_layout_dst = cute.make_layout(
        tensor_shape_transposed, stride=(M, 1)
    )  # (N, M) : (M, 1), row-major destination
    # A transposed *view* of the destination: shape (M, N) like the source,
    # striding column-first through dst memory, so src and dst can share one
    # tiler and one block coordinate.
    global_memory_layout_dst_transposed = cute.make_layout(
        tensor_shape, stride=(1, M)
    )  # (M, N) : (1, M)

    # Create tensors
    tensor_src = cute.make_tensor(ptr_a, global_memory_layout_src)
    tensor_dst = cute.make_tensor(ptr_b, global_memory_layout_dst)
    tensor_dst_transposed = cute.make_tensor(
        ptr_b, global_memory_layout_dst_transposed
    )
    # Block Tiling
    TileSizeX = 64  # bN
    TileSizeY = 32  # bM
    block_shape = (TileSizeY, TileSizeX)
    tiled_tensor_src = cute.tiled_divide(
        tensor_src, block_shape
    )  # ((TileSizeY, TileSizeX), M / TileSizeY, N / TileSizeX)
    tiled_tensor_dst_transposed = cute.tiled_divide(
        tensor_dst_transposed, block_shape
    )  # ((TileSizeY, TileSizeX), M / TileSizeY, N / TileSizeX)
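    # Concretely, with the M = 2048, N = 4096 used below, tiled_divide yields
    # shape ((32, 64), 64, 64): mode 0 is a single 32x64 tile, and modes 1
    # and 2 index the 64 x 64 grid of tiles (2048 / 32 = 4096 / 64 = 64).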
    # Thread Tiling
    ThreadBlockSizeX = 32  # tN
    ThreadBlockSizeY = 8  # tM
    thread_block_shape = (ThreadBlockSizeY, ThreadBlockSizeX)
    thread_block_shape_transposed = (ThreadBlockSizeX, ThreadBlockSizeY)
    thread_layout = cute.make_layout(
        thread_block_shape, stride=(ThreadBlockSizeX, 1)
    )  # (tM, tN) : (tN, 1)
    thread_layout_transposed = cute.make_layout(
        thread_block_shape_transposed, stride=(1, ThreadBlockSizeX)
    )  # (tN, tM) : (1, tN)
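    # Why two layouts (my reading of the experiment): with thread_layout
    # (tM, tN):(tN, 1), consecutive thread indices walk the contiguous N
    # dimension of the source, so global reads coalesce while the writes
    # through the strided dst view do not. thread_layout_transposed does the
    # opposite: consecutive threads walk the M dimension, which is contiguous
    # in the destination, trading coalesced reads for coalesced writes.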
    # Print the layouts
    cute.printf(
        """Overview:
    >>> M = {}, N = {}, bM = {}, bN = {}
    >>> global_memory_layout_src = {}
    >>> global_memory_layout_dst = {}
    >>> global_memory_layout_dst_transposed = {}
    >>> tensor_src layout = {}
    >>> tensor_dst layout = {}
    >>> tensor_dst_transposed layout = {}
    >>> tiled_tensor_src layout = {}
    >>> tiled_tensor_dst_transposed layout = {}
    >>> thread_layout = {}
    >>> thread_layout_transposed = {}""",
        M,
        N,
        TileSizeY,  # bM
        TileSizeX,  # bN
        global_memory_layout_src,
        global_memory_layout_dst,
        global_memory_layout_dst_transposed,
        tensor_src.layout,
        tensor_dst.layout,
        tensor_dst_transposed.layout,
        tiled_tensor_src.layout,
        tiled_tensor_dst_transposed.layout,
        thread_layout,
        thread_layout_transposed,
    )
    # One CTA per tile. In the kernel, bidx indexes mode 1 and bidy indexes
    # mode 2, so grid.x must be the mode-1 extent and grid.y the mode-2
    # extent (the two coincide for the M, N used below).
    if coalesced_read:
        transpose_naive_kernel(
            tiled_tensor_src, tiled_tensor_dst_transposed, thread_layout
        ).launch(
            grid=[
                cute.size(tiled_tensor_dst_transposed, mode=[1]),
                cute.size(tiled_tensor_dst_transposed, mode=[2]),
                1,
            ],
            block=[ThreadBlockSizeX * ThreadBlockSizeY, 1, 1],
        )
    else:
        transpose_naive_kernel(
            tiled_tensor_src, tiled_tensor_dst_transposed, thread_layout_transposed
        ).launch(
            grid=[
                cute.size(tiled_tensor_dst_transposed, mode=[1]),
                cute.size(tiled_tensor_dst_transposed, mode=[2]),
                1,
            ],
            block=[ThreadBlockSizeX * ThreadBlockSizeY, 1, 1],
        )

M, N = 2048, 4096
a = torch.randn(M, N, dtype=torch_dtype(cutlass.Float32), device="cuda")
ptr_a = cute_rt.make_ptr(cutlass.Float32, a.data_ptr())
# b only supplies M * N floats of backing storage; the kernel views this
# memory through the (N, M) destination layouts built above.
b = torch.randn(M, N, dtype=torch_dtype(cutlass.Float32), device="cuda")
ptr_b = cute_rt.make_ptr(cutlass.Float32, b.data_ptr())

print("Launch with coalesced read")
transpose_naive(M, N, ptr_a, ptr_b, True)
print("Launch with coalesced write")
transpose_naive(M, N, ptr_a, ptr_b, False)
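
# Sanity check (only meaningful with the copy sketched in the kernel above):
# the destination holds the (N, M) row-major transpose of a, so b's buffer
# reinterpreted as (N, M) should equal a.t().
torch.cuda.synchronize()
assert torch.equal(b.view(N, M), a.t()), "transpose result mismatch"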