Use constexpr for better performance
import functools
from typing import List, Tuple

import torch

import cutlass
import cutlass.cute as cute
import cutlass.utils as utils
import cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05
import cutlass.utils.blackwell_helpers as sm100_utils
import cutlass.utils.blockscaled_layout as blockscaled_utils
from cutlass.cute.runtime import make_ptr

from task import input_t, output_t

# Kernel configuration parameters
# Size of a TMA descriptor in bytes
bytes_per_tensormap = 128
# Number of tensormaps per CTA: A, B, SFA, SFB
num_tensormaps = 4
# Tile sizes for the M, N, K dimensions
mma_tiler_mnk = (128, 128, 256)
# K extent of a single MMA instruction
mma_inst_shape_k = 64
# FP4 data type for A and B
ab_dtype = cutlass.Float4E2M1FN
# FP8 data type for scale factors
sf_dtype = cutlass.Float8E4M3FN
# FP16 output type
c_dtype = cutlass.Float16
# Scale factor block size (16 elements share one scale)
sf_vec_size = 16
# Number of threads per CUDA thread block (CTA)
threads_per_cta = 128
# Number of pipeline stages for shared memory and TMEM
num_acc_stage = 1
num_ab_stage = 6
# Total number of columns in TMEM
num_tmem_alloc_cols = 512


# Helper function for ceiling division
def ceil_div(a, b):
    return (a + b - 1) // b
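

# Illustrative check (hypothetical sizes, not part of the kernel): a group with
# m=300, n=200 and the 128x128 tile above needs ceil_div(300, 128) = 3 tiles in
# M and ceil_div(200, 128) = 2 tiles in N, i.e. 6 CTAs. This is the same
# arithmetic the host code below uses to size the bidz grid dimension.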


# The CuTe reference implementation for NVFP4 block-scaled group GEMM
@cute.kernel
def kernel(
    tiled_mma: cute.TiledMma,
    tma_atom_a: cute.CopyAtom,
    mA_mkl: cute.Tensor,
    tma_atom_b: cute.CopyAtom,
    mB_nkl: cute.Tensor,
    tma_atom_sfa: cute.CopyAtom,
    mSFA_mkl: cute.Tensor,
    tma_atom_sfb: cute.CopyAtom,
    mSFB_nkl: cute.Tensor,
    tensor_of_abc_ptrs: cute.Tensor,
    tensor_of_sfasfb_ptrs: cute.Tensor,
    tensormaps: cute.Tensor,
    tensor_of_problem_sizes: cute.Tensor,
    a_smem_layout_staged: cute.ComposedLayout,
    b_smem_layout_staged: cute.ComposedLayout,
    sfa_smem_layout_staged: cute.Layout,
    sfb_smem_layout_staged: cute.Layout,
    cta_m_list: cutlass.Constexpr[List[int]],
    cta_n_list: cutlass.Constexpr[List[int]],
    num_tma_load_bytes: cutlass.Constexpr[int],
    num_groups: cutlass.Constexpr[cutlass.Int32],
):
    """
    GPU device kernel performing the group GEMM computation.
    """
    warp_idx = cute.arch.warp_idx()
    warp_idx = cute.arch.make_warp_uniform(warp_idx)
    tidx, _, _ = cute.arch.thread_idx()

    #
    # Delinearize bidz into coord_x, coord_y and group_idx for this CTA
    #
    bidx, bidy, bidz = cute.arch.block_idx()
    group_idx = 0
    found = False
    coord_x = 0
    coord_y = 0
    cta_rest = bidz
    for g in cutlass.range_constexpr(num_groups):
        cta_m = cta_m_list[g]
        cta_n = cta_n_list[g]
        if cta_rest >= (cta_m * cta_n):
            group_idx += 1
            cta_rest -= cta_m * cta_n
        else:
            if not found:
                coord_y = cta_rest // cta_m
                coord_x = cta_rest % cta_m
                cta_rest -= cta_m * cta_n
                found = True
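    # Worked example (hypothetical grids): with cta_m_list = [3, 2] and
    # cta_n_list = [2, 2], group 0 owns bidz 0..5 and group 1 owns bidz 6..9.
    # For bidz = 7: 7 >= 3 * 2, so group_idx = 1 and cta_rest = 1; then
    # coord_y = 1 // 2 = 0 and coord_x = 1 % 2 = 1.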

    #
    # Construct the C tensor for this CTA
    #
    mC_mnl_iter = cute.make_ptr(
        c_dtype, tensor_of_abc_ptrs[group_idx, 2], cute.AddressSpace.gmem
    ).align(32)
    m = tensor_of_problem_sizes[group_idx, 0]
    n = tensor_of_problem_sizes[group_idx, 1]
    k = tensor_of_problem_sizes[group_idx, 2]
    l = tensor_of_problem_sizes[group_idx, 3]
    mC_mnl_layout = cute.make_layout(
        (m, n, l),
        stride=(cute.assume(n, 32), 1, cute.assume(m * n, 32)),
    )
    mC_mnl = cute.make_tensor(mC_mnl_iter, mC_mnl_layout)
    # Local partition of the global C tensor
    # (bM, bN, RestM, RestN, RestL)
    gC_mnl = cute.local_tile(
        mC_mnl, cute.slice_(mma_tiler_mnk, (None, None, 0)), (coord_x, coord_y, 0)
    )

    #
    # Define shared storage for the kernel
    #
    size_tensormap_in_i64 = num_tensormaps * bytes_per_tensormap // 8

    @cute.struct
    class SharedStorage:
        tensormap_buffer: cute.struct.MemRange[
            cutlass.Int64, size_tensormap_in_i64
        ]
        ab_mbar_ptr: cute.struct.MemRange[cutlass.Int64, num_ab_stage * 2]
        acc_mbar_ptr: cute.struct.MemRange[cutlass.Int64, num_acc_stage * 2]
        tmem_holding_buf: cutlass.Int32

    smem = utils.SmemAllocator()
    storage = smem.allocate(SharedStorage)
    tensormap_smem_ptr = storage.tensormap_buffer.data_ptr()
    tensormap_a_smem_ptr = tensormap_smem_ptr
    tensormap_b_smem_ptr = tensormap_a_smem_ptr + bytes_per_tensormap // 8
    tensormap_sfa_smem_ptr = tensormap_b_smem_ptr + bytes_per_tensormap // 8
    tensormap_sfb_smem_ptr = tensormap_sfa_smem_ptr + bytes_per_tensormap // 8

    # Set up smem tensors for A, B, SFA, SFB
    # (MMA, MMA_M, MMA_K, STAGE)
    sA = smem.allocate_tensor(
        element_type=ab_dtype,
        layout=a_smem_layout_staged.outer,
        byte_alignment=128,
        swizzle=a_smem_layout_staged.inner,
    )
    # (MMA, MMA_N, MMA_K, STAGE)
    sB = smem.allocate_tensor(
        element_type=ab_dtype,
        layout=b_smem_layout_staged.outer,
        byte_alignment=128,
        swizzle=b_smem_layout_staged.inner,
    )
    # (MMA, MMA_M, MMA_K, STAGE)
    sSFA = smem.allocate_tensor(
        element_type=sf_dtype,
        layout=sfa_smem_layout_staged,
        byte_alignment=128,
    )
    # (MMA, MMA_N, MMA_K, STAGE)
    sSFB = smem.allocate_tensor(
        element_type=sf_dtype,
        layout=sfb_smem_layout_staged,
        byte_alignment=128,
    )

    # Initialize the mainloop ab_pipeline, acc_pipeline and their states
    ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
    ab_pipeline_consumer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, 1)
    ab_producer, ab_consumer = pipeline.PipelineTmaUmma.create(
        barrier_storage=storage.ab_mbar_ptr.data_ptr(),
        num_stages=num_ab_stage,
        producer_group=ab_pipeline_producer_group,
        consumer_group=ab_pipeline_consumer_group,
        tx_count=num_tma_load_bytes,
    ).make_participants()
    acc_producer, acc_consumer = pipeline.PipelineUmmaAsync.create(
        barrier_storage=storage.acc_mbar_ptr.data_ptr(),
        num_stages=num_acc_stage,
        producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
        consumer_group=pipeline.CooperativeGroup(
            pipeline.Agent.Thread,
            threads_per_cta,
        ),
    ).make_participants()

    #
    # Local_tile partition of the global tensors
    #
    # (bM, bK, RestM, RestK, RestL)
    gA_mkl = cute.local_tile(
        mA_mkl, cute.slice_(mma_tiler_mnk, (None, 0, None)), (None, None, None)
    )
    # (bN, bK, RestN, RestK, RestL)
    gB_nkl = cute.local_tile(
        mB_nkl, cute.slice_(mma_tiler_mnk, (0, None, None)), (None, None, None)
    )
    # (bM, bK, RestM, RestK, RestL)
    gSFA_mkl = cute.local_tile(
        mSFA_mkl, cute.slice_(mma_tiler_mnk, (None, 0, None)), (None, None, None)
    )
    # (bN, bK, RestN, RestK, RestL)
    gSFB_nkl = cute.local_tile(
        mSFB_nkl, cute.slice_(mma_tiler_mnk, (0, None, None)), (None, None, None)
    )

    #
    # Partition the global tensors for TiledMMA A/B/C
    #
    thr_mma = tiled_mma.get_slice(tidx)
    # (MMA, MMA_M, MMA_K, RestM, RestK, RestL)
    tCgA = thr_mma.partition_A(gA_mkl)
    # (MMA, MMA_N, MMA_K, RestN, RestK, RestL)
    tCgB = thr_mma.partition_B(gB_nkl)
    # (MMA, MMA_M, MMA_K, RestM, RestK, RestL)
    tCgSFA = thr_mma.partition_A(gSFA_mkl)
    # (MMA, MMA_N, MMA_K, RestN, RestK, RestL)
    tCgSFB = thr_mma.partition_B(gSFB_nkl)
    # (MMA, MMA_M, MMA_N, RestM, RestN, RestL)
    tCgC = thr_mma.partition_C(gC_mnl)

    # Update the TMA descriptors with the correct shapes and strides
    tensormap_manager = utils.TensorMapManager(
        utils.TensorMapUpdateMode.SMEM,
        128,
    )
    tensormap_a_gmem_ptr = tensormap_manager.get_tensormap_ptr(
        tensormaps[(bidz, 0, None)].iterator
    )
    tensormap_b_gmem_ptr = tensormap_manager.get_tensormap_ptr(
        tensormaps[(bidz, 1, None)].iterator
    )
    tensormap_sfa_gmem_ptr = tensormap_manager.get_tensormap_ptr(
        tensormaps[(bidz, 2, None)].iterator
    )
    tensormap_sfb_gmem_ptr = tensormap_manager.get_tensormap_ptr(
        tensormaps[(bidz, 3, None)].iterator
    )
    mA_mkl_iter = cute.make_ptr(
        ab_dtype, tensor_of_abc_ptrs[group_idx, 0], cute.AddressSpace.gmem
    ).align(32)
    mB_nkl_iter = cute.make_ptr(
        ab_dtype, tensor_of_abc_ptrs[group_idx, 1], cute.AddressSpace.gmem
    ).align(32)
    sfa_mkl_iter = cute.make_ptr(
        sf_dtype, tensor_of_sfasfb_ptrs[group_idx, 0], cute.AddressSpace.gmem
    ).align(32)
    sfb_nkl_iter = cute.make_ptr(
        sf_dtype, tensor_of_sfasfb_ptrs[group_idx, 1], cute.AddressSpace.gmem
    ).align(32)
    mA_mkl_layout = cute.make_layout(
        (m, k, l), stride=(cute.assume(k, 32), 1, cute.assume(m * k, 32))
    )
    mB_nkl_layout = cute.make_layout(
        (n, k, l), stride=(cute.assume(k, 32), 1, cute.assume(n * k, 32))
    )
    # SFA and SFB follow the specialized block-scaling layout described at:
    # https://docs.nvidia.com/cuda/cublas/index.html?highlight=fp4#d-block-scaling-factors-layout
    atom_shape = ((32, 4), (sf_vec_size, 4))
    atom_stride = ((16, 4), (0, 1))
    sfa_layout = cute.tile_to_shape(
        cute.make_layout(atom_shape, stride=atom_stride),
        mA_mkl_layout.shape,
        (2, 1, 3),
    )
    sfb_layout = cute.tile_to_shape(
        cute.make_layout(atom_shape, stride=atom_stride),
        mB_nkl_layout.shape,
        (2, 1, 3),
    )
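    # Reading the atom above: the ((32, 4), (16, 4)) shape covers a block of
    # 128 rows by 64 K elements (i.e. 4 scale columns). The stride 0 on the
    # sf_vec_size mode encodes that all 16 elements of a vector share one
    # scale, and the remaining strides interleave rows so that each 128x4
    # block of one-byte scales lands in a single contiguous 512-byte atom,
    # matching the cuBLAS layout linked above.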
    real_tensor_a = cute.make_tensor(mA_mkl_iter, mA_mkl_layout)
    real_tensor_b = cute.make_tensor(mB_nkl_iter, mB_nkl_layout)
    real_tensor_sfa = cute.make_tensor(sfa_mkl_iter, sfa_layout)
    real_tensor_sfb = cute.make_tensor(sfb_nkl_iter, sfb_layout)

    # Let warp 0 initialize the tensormaps
    if warp_idx == 0:
        tensormap_manager.init_tensormap_from_atom(
            tma_atom_a, tensormap_a_smem_ptr, 0
        )
        tensormap_manager.init_tensormap_from_atom(
            tma_atom_b, tensormap_b_smem_ptr, 0
        )
        tensormap_manager.init_tensormap_from_atom(
            tma_atom_sfa, tensormap_sfa_smem_ptr, 0
        )
        tensormap_manager.init_tensormap_from_atom(
            tma_atom_sfb, tensormap_sfb_smem_ptr, 0
        )
        tensormap_manager.update_tensormap(
            (
                real_tensor_a,
                real_tensor_b,
                real_tensor_sfa,
                real_tensor_sfb,
            ),
            (tma_atom_a, tma_atom_b, tma_atom_sfa, tma_atom_sfb),
            (
                tensormap_a_gmem_ptr,
                tensormap_b_gmem_ptr,
                tensormap_sfa_gmem_ptr,
                tensormap_sfb_gmem_ptr,
            ),
            0,  # tma warp id
            (
                tensormap_a_smem_ptr,
                tensormap_b_smem_ptr,
                tensormap_sfa_smem_ptr,
                tensormap_sfb_smem_ptr,
            ),
        )
        tensormap_manager.fence_tensormap_update(tensormap_a_gmem_ptr)
        tensormap_manager.fence_tensormap_update(tensormap_b_gmem_ptr)
        tensormap_manager.fence_tensormap_update(tensormap_sfa_gmem_ptr)
        tensormap_manager.fence_tensormap_update(tensormap_sfb_gmem_ptr)
    cute.arch.barrier()
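    # Descriptor protocol recap: warp 0 copies the template descriptors into
    # smem, patches shapes/strides/base pointers for this group's tensors,
    # writes the result back to this CTA's gmem slots, and fences so the TMA
    # unit observes the update; the CTA-wide barrier keeps the other warps
    # from issuing loads until the descriptors are valid.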

    #
    # Partition global/shared tensors for TMA loads of A/B/SFA/SFB
    #
    # TMA partition_S/D for A
    # ((atom_v, rest_v), STAGE)
    # ((atom_v, rest_v), RestM, RestK, RestL)
    tAsA, tAgA = cpasync.tma_partition(
        tma_atom_a,
        0,
        cute.make_layout(1),
        cute.group_modes(sA, 0, 3),
        cute.group_modes(tCgA, 0, 3),
    )
    # TMA partition_S/D for B
    # ((atom_v, rest_v), STAGE)
    # ((atom_v, rest_v), RestN, RestK, RestL)
    tBsB, tBgB = cpasync.tma_partition(
        tma_atom_b,
        0,
        cute.make_layout(1),
        cute.group_modes(sB, 0, 3),
        cute.group_modes(tCgB, 0, 3),
    )
    # TMA partition_S/D for SFA
    # ((atom_v, rest_v), STAGE)
    # ((atom_v, rest_v), RestM, RestK, RestL)
    tAsSFA, tAgSFA = cpasync.tma_partition(
        tma_atom_sfa,
        0,
        cute.make_layout(1),
        cute.group_modes(sSFA, 0, 3),
        cute.group_modes(tCgSFA, 0, 3),
    )
    tAsSFA = cute.filter_zeros(tAsSFA)
    tAgSFA = cute.filter_zeros(tAgSFA)
    # TMA partition_S/D for SFB
    # ((atom_v, rest_v), STAGE)
    # ((atom_v, rest_v), RestN, RestK, RestL)
    tBsSFB, tBgSFB = cpasync.tma_partition(
        tma_atom_sfb,
        0,
        cute.make_layout(1),
        cute.group_modes(sSFB, 0, 3),
        cute.group_modes(tCgSFB, 0, 3),
    )
    tBsSFB = cute.filter_zeros(tBsSFB)
    tBgSFB = cute.filter_zeros(tBgSFB)

    #
    # Partition shared/tensor memory tensors for TiledMMA A/B/C
    #
    # (MMA, MMA_M, MMA_K, STAGE)
    tCrA = tiled_mma.make_fragment_A(sA)
    # (MMA, MMA_N, MMA_K, STAGE)
    tCrB = tiled_mma.make_fragment_B(sB)
    # (MMA, MMA_M, MMA_N)
    acc_shape = tiled_mma.partition_shape_C(mma_tiler_mnk[:2])
    # (MMA, MMA_M, MMA_N)
    tCtAcc_fake = tiled_mma.make_fragment_C(acc_shape)

    #
    # Allocate the tensor memory buffer
    #
    tmem_alloc_barrier = pipeline.NamedBarrier(
        barrier_id=1,
        num_threads=threads_per_cta,
    )
    tmem = utils.TmemAllocator(
        storage.tmem_holding_buf,
        barrier_for_retrieve=tmem_alloc_barrier,
    )
    tmem.allocate(num_tmem_alloc_cols)
    tmem.wait_for_alloc()
    acc_tmem_ptr = tmem.retrieve_ptr(cutlass.Float32)
    tCtAcc = cute.make_tensor(acc_tmem_ptr, tCtAcc_fake.layout)

    #
    # Make the SFA/SFB tmem tensors
    #
    # Get the SFA tmem ptr
    sfa_tmem_ptr = cute.recast_ptr(
        acc_tmem_ptr + tcgen05.find_tmem_tensor_col_offset(tCtAcc),
        dtype=sf_dtype,
    )
    # (MMA, MMA_M, MMA_K)
    tCtSFA_layout = blockscaled_utils.make_tmem_layout_sfa(
        tiled_mma,
        mma_tiler_mnk,
        sf_vec_size,
        cute.slice_(sfa_smem_layout_staged, (None, None, None, 0)),
    )
    tCtSFA = cute.make_tensor(sfa_tmem_ptr, tCtSFA_layout)
    # Get the SFB tmem ptr
    sfb_tmem_ptr = cute.recast_ptr(
        acc_tmem_ptr
        + tcgen05.find_tmem_tensor_col_offset(tCtAcc)
        + tcgen05.find_tmem_tensor_col_offset(tCtSFA),
        dtype=sf_dtype,
    )
    # (MMA, MMA_N, MMA_K)
    tCtSFB_layout = blockscaled_utils.make_tmem_layout_sfb(
        tiled_mma,
        mma_tiler_mnk,
        sf_vec_size,
        cute.slice_(sfb_smem_layout_staged, (None, None, None, 0)),
    )
    tCtSFB = cute.make_tensor(sfb_tmem_ptr, tCtSFB_layout)

    #
    # Partition for the S2T copy of SFA/SFB
    #
    # Make the S2T CopyAtom
    copy_atom_s2t = cute.make_copy_atom(
        tcgen05.Cp4x32x128bOp(tcgen05.CtaGroup.ONE),
        sf_dtype,
    )
    # (MMA, MMA_MN, MMA_K, STAGE)
    tCsSFA_compact = cute.filter_zeros(sSFA)
    tCtSFA_compact = cute.filter_zeros(tCtSFA)
    tiled_copy_s2t_sfa = tcgen05.make_s2t_copy(copy_atom_s2t, tCtSFA_compact)
    thr_copy_s2t_sfa = tiled_copy_s2t_sfa.get_slice(0)
    # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K, STAGE)
    tCsSFA_compact_s2t_ = thr_copy_s2t_sfa.partition_S(tCsSFA_compact)
    # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K, STAGE)
    tCsSFA_compact_s2t = tcgen05.get_s2t_smem_desc_tensor(
        tiled_copy_s2t_sfa, tCsSFA_compact_s2t_
    )
    # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K)
    tCtSFA_compact_s2t = thr_copy_s2t_sfa.partition_D(tCtSFA_compact)
    # (MMA, MMA_MN, MMA_K, STAGE)
    tCsSFB_compact = cute.filter_zeros(sSFB)
    # (MMA, MMA_MN, MMA_K)
    tCtSFB_compact = cute.filter_zeros(tCtSFB)
    tiled_copy_s2t_sfb = tcgen05.make_s2t_copy(copy_atom_s2t, tCtSFB_compact)
    thr_copy_s2t_sfb = tiled_copy_s2t_sfb.get_slice(0)
    # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K, STAGE)
    tCsSFB_compact_s2t_ = thr_copy_s2t_sfb.partition_S(tCsSFB_compact)
    # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K, STAGE)
    tCsSFB_compact_s2t = tcgen05.get_s2t_smem_desc_tensor(
        tiled_copy_s2t_sfb, tCsSFB_compact_s2t_
    )
    # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K)
    tCtSFB_compact_s2t = thr_copy_s2t_sfb.partition_D(tCtSFB_compact)

    # Number of K loops
    k_tile_cnt = cute.ceil_div(real_tensor_a.shape[1], mma_tiler_mnk[2])

    #
    # Slice to this CTA's mma tile index
    #
    mma_tile_coord_mnl = (coord_x, coord_y, 0)
    # ((atom_v, rest_v), RestK)
    tAgA = tAgA[(None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2])]
    # ((atom_v, rest_v), RestK)
    tBgB = tBgB[(None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2])]
    # ((atom_v, rest_v), RestK)
    tAgSFA = tAgSFA[(None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2])]
    # ((atom_v, rest_v), RestK)
    tBgSFB = tBgSFB[(None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2])]

    #
    # Main loop
    #
    if warp_idx == 0:
        # Execute the k_tile loop
        for k_tile in range(k_tile_cnt):
            # Wait for an empty AB buffer stage
            ab_empty = ab_producer.acquire_and_advance()
            # TMA load A/B/SFA/SFB to shared memory
            cute.copy(
                tma_atom_a,
                tAgA[(None, k_tile)],
                tAsA[(None, ab_empty.index)],
                tma_bar_ptr=ab_empty.barrier,
                tma_desc_ptr=tensormap_manager.get_tensormap_ptr(
                    tensormap_a_gmem_ptr,
                    cute.AddressSpace.generic,
                ),
            )
            cute.copy(
                tma_atom_b,
                tBgB[(None, k_tile)],
                tBsB[(None, ab_empty.index)],
                tma_bar_ptr=ab_empty.barrier,
                tma_desc_ptr=tensormap_manager.get_tensormap_ptr(
                    tensormap_b_gmem_ptr,
                    cute.AddressSpace.generic,
                ),
            )
            cute.copy(
                tma_atom_sfa,
                tAgSFA[(None, k_tile)],
                tAsSFA[(None, ab_empty.index)],
                tma_bar_ptr=ab_empty.barrier,
                tma_desc_ptr=tensormap_manager.get_tensormap_ptr(
                    tensormap_sfa_gmem_ptr,
                    cute.AddressSpace.generic,
                ),
            )
            cute.copy(
                tma_atom_sfb,
                tBgSFB[(None, k_tile)],
                tBsSFB[(None, ab_empty.index)],
                tma_bar_ptr=ab_empty.barrier,
                tma_desc_ptr=tensormap_manager.get_tensormap_ptr(
                    tensormap_sfb_gmem_ptr,
                    cute.AddressSpace.generic,
                ),
            )

    if warp_idx == 1:
        # Wait for an empty accumulator buffer stage
        acc_empty = acc_producer.acquire_and_advance()
        # Clear the ACCUMULATE field for the first k_tile iteration
        tiled_mma.set(tcgen05.Field.ACCUMULATE, False)
        # Execute the k_tile loop
        for k_tile in range(k_tile_cnt):
            # Wait for a full AB buffer stage
            ab_full = ab_consumer.wait_and_advance()
            # Copy SFA/SFB from shared memory to TMEM
            s2t_stage_coord = (None, None, None, None, ab_full.index)
            tCsSFA_compact_s2t_staged = tCsSFA_compact_s2t[s2t_stage_coord]
            tCsSFB_compact_s2t_staged = tCsSFB_compact_s2t[s2t_stage_coord]
            cute.copy(
                tiled_copy_s2t_sfa,
                tCsSFA_compact_s2t_staged,
                tCtSFA_compact_s2t,
            )
            cute.copy(
                tiled_copy_s2t_sfb,
                tCsSFB_compact_s2t_staged,
                tCtSFB_compact_s2t,
            )
            # tCtAcc += (A * SFA) @ (B * SFB) over the kblocks of this tile
            num_kblocks = cute.size(tCrA, mode=[2])
            for kblock_idx in cutlass.range(num_kblocks, unroll_full=True):
                kblock_coord = (
                    None,
                    None,
                    kblock_idx,
                    ab_full.index,
                )
                # Point the tiled_mma at this kblock's SFA/SFB in TMEM
                sf_kblock_coord = (None, None, kblock_idx)
                tiled_mma.set(
                    tcgen05.Field.SFA,
                    tCtSFA[sf_kblock_coord].iterator,
                )
                tiled_mma.set(
                    tcgen05.Field.SFB,
                    tCtSFB[sf_kblock_coord].iterator,
                )
                cute.gemm(
                    tiled_mma,
                    tCtAcc,
                    tCrA[kblock_coord],
                    tCrB[kblock_coord],
                    tCtAcc,
                )
                # Enable accumulation into tCtAcc after the first kblock
                tiled_mma.set(tcgen05.Field.ACCUMULATE, True)
            # Async arrive: the AB buffer stage is empty again
            ab_full.release()
        acc_empty.commit()

    #
    # Epilogue
    # Partition for the epilogue
    #
    op = tcgen05.Ld32x32bOp(tcgen05.Repetition.x128, tcgen05.Pack.NONE)
    copy_atom_t2r = cute.make_copy_atom(op, cutlass.Float32)
    tiled_copy_t2r = tcgen05.make_tmem_copy(copy_atom_t2r, tCtAcc[None, 0, 0])
    thr_copy_t2r = tiled_copy_t2r.get_slice(tidx)
    # (TmemCpy, NumTmemCpy)
    tDtAcc = thr_copy_t2r.partition_S(tCtAcc[None, 0, 0])
    # (TmemCpy, NumTmemCpy)
    tDgC = thr_copy_t2r.partition_D(tCgC[None, 0, 0])
    # (TmemCpy, NumTmemCpy)
    tDrAcc = cute.make_rmem_tensor(tDgC.shape, cutlass.Float32)
    # (TmemCpy, NumTmemCpy)
    tDrC = cute.make_rmem_tensor(tDgC.shape, c_dtype)
    # Release the TMEM allocation lock
    tmem.relinquish_alloc_permit()
    # Wait for a full accumulator buffer stage
    acc_full = acc_consumer.wait_and_advance()
    # Copy the accumulator to registers
    cute.copy(tiled_copy_t2r, tDtAcc, tDrAcc)
    acc_vec = tDrAcc.load()
    tDrC.store(acc_vec.to(c_dtype))
    # SIMT store atom, used here for functional simplicity; a TMA store would
    # reduce address and predicate calculation instructions for better performance
    simt_atom = cute.make_copy_atom(
        cute.nvgpu.CopyUniversalOp(), c_dtype, num_bits_per_copy=16
    )
    thread_layout = cute.make_layout(
        (1, threads_per_cta), stride=(threads_per_cta, 1)
    )
    value_layout = cute.make_layout((1, 1))
    tiled_copy_r2g = cute.make_tiled_copy_tv(
        simt_atom, thread_layout, value_layout
    )
    thr_copy_r2g = tiled_copy_r2g.get_slice(tidx)
    cC = cute.make_identity_tensor(gC_mnl.shape)
    # ((atom_v, rest_v), NumGmemCpy)
    tDcC = thr_copy_r2g.partition_D(cC)
    # ((atom_v, rest_v), NumGmemCpy)
    tDpC = cute.make_rmem_tensor(tDrC.shape, cutlass.Boolean)
    residue_m = mC_mnl.shape[0] - cutlass.Int32(coord_x) * mma_tiler_mnk[0]
    residue_n = mC_mnl.shape[1] - cutlass.Int32(coord_y) * mma_tiler_mnk[1]
    for i in range(cute.size(tDrC.shape)):
        # Swap residue_m and residue_n to match the coordinate order of tDcC
        tDpC[i] = cute.elem_less(tDcC[i], (residue_n, residue_m))
    cute.copy(
        simt_atom, cute.flatten(tDrC), cute.flatten(tDgC), pred=cute.flatten(tDpC)
    )
    acc_full.release()
    # Deallocate TMEM
    cute.arch.barrier()
    tmem.free(acc_tmem_ptr)
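

# Kernel structure recap: warp 0 is the TMA producer (loading A/B/SFA/SFB
# tiles through the per-CTA tensormaps), warp 1 is the MMA consumer
# (S2T-copying scale factors into TMEM and issuing the block-scaled tcgen05
# MMAs), and all 128 threads cooperate on the epilogue (TMEM -> registers ->
# predicated global store of C).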


# Host-side JIT function to prepare tensors and launch the GPU kernel.
@cute.jit
def my_kernel(
    ptr_of_tensor_of_problem_sizes: cute.Pointer,
    ptr_of_tensor_of_abc_ptrs: cute.Pointer,
    ptr_of_tensor_of_sfasfb_ptrs: cute.Pointer,
    ptr_of_tensor_of_tensormap: cute.Pointer,
    total_num_clusters: cutlass.Int32,
    problem_sizes: cutlass.Constexpr[List[
        Tuple[int, int, int, int]
    ]],  # Problem sizes for each group
    num_groups: cutlass.Constexpr[cutlass.Int32],
):
    tensor_of_abc_ptrs = cute.make_tensor(
        ptr_of_tensor_of_abc_ptrs, cute.make_layout((num_groups, 3), stride=(3, 1))
    )
    tensor_of_sfasfb_ptrs = cute.make_tensor(
        ptr_of_tensor_of_sfasfb_ptrs, cute.make_layout((num_groups, 2), stride=(2, 1))
    )
    tensor_of_problem_sizes = cute.make_tensor(
        ptr_of_tensor_of_problem_sizes, cute.make_layout((num_groups, 4), stride=(4, 1))
    )
    tensor_of_tensormap = cute.make_tensor(
        ptr_of_tensor_of_tensormap,
        cute.make_layout((total_num_clusters, 4, 16), stride=(64, 16, 1)),
    )
    # Use fake shapes for the initial TMA descriptor and atom setup.
    # The real TMA descriptors and atoms are updated during kernel execution.
    min_a_shape = (cutlass.Int32(64), cutlass.Int32(64), cutlass.Int32(64), cutlass.Int32(1))
    min_b_shape = (cutlass.Int32(64), cutlass.Int32(64), cutlass.Int32(64), cutlass.Int32(1))
    initial_a = cute.make_tensor(
        cute.make_ptr(ab_dtype, 0, cute.AddressSpace.gmem, assumed_align=16),
        cute.make_layout(
            (min_a_shape[0], cute.assume(min_a_shape[2], 32), min_a_shape[3]),
            stride=(
                cute.assume(min_a_shape[2], 32),
                1,
                cute.assume(min_a_shape[0] * min_a_shape[2], 32),
            ),
        ),
    )
    initial_b = cute.make_tensor(
        cute.make_ptr(ab_dtype, 0, cute.AddressSpace.gmem, assumed_align=16),
        cute.make_layout(
            (min_b_shape[1], cute.assume(min_b_shape[2], 32), min_b_shape[3]),
            stride=(
                cute.assume(min_b_shape[2], 32),
                1,
                cute.assume(min_b_shape[1] * min_b_shape[2], 32),
            ),
        ),
    )
    # Set up the SFA/SFB layouts by tiling the scale-factor atom to the A/B shapes
    # ((Atom_M, Rest_M), (Atom_K, Rest_K), RestL)
    sfa_layout = blockscaled_utils.tile_atom_to_shape_SF(
        initial_a.shape, sf_vec_size
    )
    # ((Atom_N, Rest_N), (Atom_K, Rest_K), RestL)
    sfb_layout = blockscaled_utils.tile_atom_to_shape_SF(
        initial_b.shape, sf_vec_size
    )
    # Create the initial SFA and SFB tensors with fake shapes and null pointers.
    initial_sfa = cute.make_tensor(
        cute.make_ptr(sf_dtype, 0, cute.AddressSpace.gmem, assumed_align=16),
        sfa_layout,
    )
    initial_sfb = cute.make_tensor(
        cute.make_ptr(sf_dtype, 0, cute.AddressSpace.gmem, assumed_align=16),
        sfb_layout,
    )
    # Select the MMA operation
    mma_op = tcgen05.MmaMXF4NVF4Op(
        sf_dtype,
        (mma_tiler_mnk[0], mma_tiler_mnk[1], mma_inst_shape_k),
        tcgen05.CtaGroup.ONE,
        tcgen05.OperandSource.SMEM,
    )
    tiled_mma = cute.make_tiled_mma(mma_op)
    cluster_layout_vmnk = cute.tiled_divide(
        cute.make_layout((1, 1, 1)),
        (tiled_mma.thr_id.shape,),
    )
    # Compute the A/B/SFA/SFB shared memory layouts
    a_smem_layout_staged = sm100_utils.make_smem_layout_a(
        tiled_mma,
        mma_tiler_mnk,
        ab_dtype,
        num_ab_stage,
    )
    b_smem_layout_staged = sm100_utils.make_smem_layout_b(
        tiled_mma,
        mma_tiler_mnk,
        ab_dtype,
        num_ab_stage,
    )
    sfa_smem_layout_staged = blockscaled_utils.make_smem_layout_sfa(
        tiled_mma,
        mma_tiler_mnk,
        sf_vec_size,
        num_ab_stage,
    )
    sfb_smem_layout_staged = blockscaled_utils.make_smem_layout_sfb(
        tiled_mma,
        mma_tiler_mnk,
        sf_vec_size,
        num_ab_stage,
    )
    atom_thr_size = cute.size(tiled_mma.thr_id.shape)
    # Set up TMA for A
    a_smem_layout = cute.slice_(a_smem_layout_staged, (None, None, None, 0))
    tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A(
        cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE),
        initial_a,
        a_smem_layout,
        mma_tiler_mnk,
        tiled_mma,
        cluster_layout_vmnk.shape,
    )
    # Set up TMA for B
    b_smem_layout = cute.slice_(b_smem_layout_staged, (None, None, None, 0))
    tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B(
        cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE),
        initial_b,
        b_smem_layout,
        mma_tiler_mnk,
        tiled_mma,
        cluster_layout_vmnk.shape,
    )
    # Set up TMA for SFA
    sfa_smem_layout = cute.slice_(
        sfa_smem_layout_staged, (None, None, None, 0)
    )
    tma_atom_sfa, tma_tensor_sfa = cute.nvgpu.make_tiled_tma_atom_A(
        cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE),
        initial_sfa,
        sfa_smem_layout,
        mma_tiler_mnk,
        tiled_mma,
        cluster_layout_vmnk.shape,
        internal_type=cutlass.Int16,
    )
    # Set up TMA for SFB
    sfb_smem_layout = cute.slice_(
        sfb_smem_layout_staged, (None, None, None, 0)
    )
    tma_atom_sfb, tma_tensor_sfb = cute.nvgpu.make_tiled_tma_atom_B(
        cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE),
        initial_sfb,
        sfb_smem_layout,
        mma_tiler_mnk,
        tiled_mma,
        cluster_layout_vmnk.shape,
        internal_type=cutlass.Int16,
    )
    # Compute the TMA load bytes per stage
    a_copy_size = cute.size_in_bytes(ab_dtype, a_smem_layout)
    b_copy_size = cute.size_in_bytes(ab_dtype, b_smem_layout)
    sfa_copy_size = cute.size_in_bytes(sf_dtype, sfa_smem_layout)
    sfb_copy_size = cute.size_in_bytes(sf_dtype, sfb_smem_layout)
    num_tma_load_bytes = (
        a_copy_size + b_copy_size + sfa_copy_size + sfb_copy_size
    ) * atom_thr_size
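    # Back-of-envelope check (assuming no smem padding): per stage, A and B are
    # each 128 x 256 FP4 values = 16384 bytes, and SFA and SFB are each
    # 128 x (256 / 16) FP8 scales = 2048 bytes; with atom_thr_size = 1 the
    # barrier therefore expects roughly 2 * 16384 + 2 * 2048 = 36864 bytes
    # per TMA stage.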
    # Store the CTA grid shape for each group
    cta_m_list = []
    cta_n_list = []
    for group_idx in cutlass.range_constexpr(num_groups):
        x, y = cute.ceil_div(problem_sizes[group_idx][:2], mma_tiler_mnk[0:2])
        cta_m_list.append(x)
        cta_n_list.append(y)
    # Compute the grid size
    grid = (1, 1, total_num_clusters)
    # Launch the kernel
    kernel(
        # MMA (Matrix Multiply-Accumulate) configuration
        tiled_mma,  # Tiled MMA object defining the NVFP4 GEMM compute pattern
        # TMA (Tensor Memory Accelerator) atoms and tensors for input matrix A
        tma_atom_a,  # TMA copy atom defining how to load A from global memory
        tma_tensor_a,  # Tensor descriptor for A (created from the fake A tensor)
        # TMA atoms and tensors for input matrix B
        tma_atom_b,  # TMA copy atom defining how to load B from global memory
        tma_tensor_b,  # Tensor descriptor for B (created from the fake B tensor)
        # TMA atoms and tensors for scale factor A
        tma_atom_sfa,  # TMA copy atom for loading the scale factors of A
        tma_tensor_sfa,  # Tensor descriptor for SFA (block scale factors of A)
        # TMA atoms and tensors for scale factor B
        tma_atom_sfb,  # TMA copy atom for loading the scale factors of B
        tma_tensor_sfb,  # Tensor descriptor for SFB (block scale factors of B)
        # Runtime tensor metadata for dynamic group access
        tensor_of_abc_ptrs,  # Device tensor of pointers to A, B, C for all groups
        tensor_of_sfasfb_ptrs,  # Device tensor of pointers to SFA, SFB for all groups
        tensor_of_tensormap,  # Pre-allocated buffer for tensormap descriptors per CTA
        tensor_of_problem_sizes,  # Device tensor of (m, n, k, l) for each group
        # Shared memory layouts with staging for pipelined execution
        a_smem_layout_staged,  # Staged shared memory layout for A (includes stage dimension)
        b_smem_layout_staged,  # Staged shared memory layout for B (includes stage dimension)
        sfa_smem_layout_staged,  # Staged shared memory layout for SFA (includes stage dimension)
        sfb_smem_layout_staged,  # Staged shared memory layout for SFB (includes stage dimension)
        # CTA grid configuration per group
        cta_m_list,  # List of M tile counts for each group
        cta_n_list,  # List of N tile counts for each group
        # Pipeline synchronization parameters
        num_tma_load_bytes,  # Total bytes per TMA transaction (for barrier setup)
        num_groups,  # Number of groups (known at compile time)
    ).launch(
        grid=grid,
        block=[threads_per_cta, 1, 1],
        cluster=(1, 1, 1),
    )
    return


# Global cache for compiled kernels (keyed by problem sizes)
_compiled_kernel_cache = {}


# Compile the kernel once and cache it, so that users can run the kernel
# multiple times and get more accurate timing results.
def compile_kernel(problem_sizes):
    """
    Compile the kernel once and cache it using problem_sizes as the key.
    This should be called before any timing measurements.

    Returns:
        The compiled kernel function
    """
    global _compiled_kernel_cache
    # Convert the problem_sizes list to a hashable tuple of tuples for use as
    # a dictionary key
    cache_key = tuple(tuple(ps) for ps in problem_sizes)
    # Check whether we already have a compiled kernel for these problem sizes
    if cache_key in _compiled_kernel_cache:
        return _compiled_kernel_cache[cache_key]
    cute_ptr_of_tensor_of_problem_sizes = make_ptr(
        cutlass.Int32, 0, cute.AddressSpace.gmem, assumed_align=16,
    )
    cute_ptr_of_tensor_of_abc_ptrs = make_ptr(
        cutlass.Int64, 0, cute.AddressSpace.gmem, assumed_align=16,
    )
    cute_ptr_of_tensor_of_sfasfb_ptrs = make_ptr(
        cutlass.Int64, 0, cute.AddressSpace.gmem, assumed_align=16,
    )
    # Fake cluster count, for compilation only.
    total_num_clusters = cutlass.Int32(1)
    num_groups = cutlass.Int32(len(problem_sizes))
    # Each cluster needs its own set of tensormaps (one each for A, B, SFA, SFB)
    # Shape: (total_num_clusters, num_tensormaps=4, bytes_per_tensormap/8=16)
    cute_ptr_of_tensor_of_tensormap = make_ptr(
        cutlass.Int64, 0, cute.AddressSpace.gmem, assumed_align=16,
    )
    compiled_func = cute.compile(
        my_kernel,
        cute_ptr_of_tensor_of_problem_sizes,
        cute_ptr_of_tensor_of_abc_ptrs,
        cute_ptr_of_tensor_of_sfasfb_ptrs,
        cute_ptr_of_tensor_of_tensormap,
        total_num_clusters,
        problem_sizes,
        num_groups,
    )
    # Store the compiled kernel in the cache with problem_sizes as the key
    _compiled_kernel_cache[cache_key] = compiled_func
    return compiled_func


def custom_kernel(data: input_t) -> output_t:
    """
    Execute the block-scaled group GEMM kernel.

    This is the main entry point called by the evaluation framework.
    It converts PyTorch tensors to CuTe tensors, launches the kernel,
    and returns the result.

    Args:
        data: Tuple of (abc_tensors, sfasfb_tensors, sfasfb_reordered_tensors,
            problem_sizes) where:
            abc_tensors: list of tuples (a, b, c) where
                a is torch.Tensor[float4e2m1fn_x2] of shape [m, k // 2, l]
                b is torch.Tensor[float4e2m1fn_x2] of shape [n, k // 2, l]
                c is torch.Tensor[float16] of shape [m, n, l]
            sfasfb_tensors: list of tuples (sfa, sfb) where
                sfa is torch.Tensor[float8_e4m3fn] of shape [m, k // 16, l]
                sfb is torch.Tensor[float8_e4m3fn] of shape [n, k // 16, l]
            problem_sizes: list of tuples (m, n, k, l)
            Each group has its own a, b, c, sfa, sfb with its own m, n, k, l
            problem sizes; l should always be 1 for each group.
            The list size is the number of groups.

    Returns:
        List of c tensors, where c is torch.Tensor[float16] of shape [m, n, l]
        for each group
    """
    abc_tensors, _, sfasfb_reordered_tensors, problem_sizes = data
    compiled_func = compile_kernel(problem_sizes)
    # Extract raw data pointers from all input tensors for each group.
    # These are passed to the GPU kernel to access the actual tensor data.
    abc_ptrs = []
    sfasfb_ptrs = []
    for (a, b, c), (sfa_reordered, sfb_reordered) in zip(
        abc_tensors, sfasfb_reordered_tensors
    ):
        # Store pointers to the A, B, and C matrices for this group
        abc_ptrs.append((a.data_ptr(), b.data_ptr(), c.data_ptr()))
        # Store pointers to the scale factor tensors for this group
        sfasfb_ptrs.append((sfa_reordered.data_ptr(), sfb_reordered.data_ptr()))
    # Create a torch tensor holding the problem sizes for all groups.
    # Shape: (num_groups, 4) where each row contains (m, n, k, l) for that group.
    # Layout (num_groups, 4):(4, 1) means row-major storage.
    tensor_of_problem_sizes = torch.tensor(
        problem_sizes, dtype=torch.int32, device="cuda"
    )
    # Create torch tensors holding the data pointers for all groups.
    # These let the GPU kernel dynamically access different tensors per group:
    # tensor_of_abc_ptrs: shape (num_groups, 3) with (a_ptr, b_ptr, c_ptr) per group
    # tensor_of_sfasfb_ptrs: shape (num_groups, 2) with (sfa_ptr, sfb_ptr) per group
    tensor_of_abc_ptrs = torch.tensor(abc_ptrs, dtype=torch.int64, device="cuda")
    tensor_of_sfasfb_ptrs = torch.tensor(sfasfb_ptrs, dtype=torch.int64, device="cuda")
    # Compute the tile shape for each CUDA thread block (CTA)
    # cta_tile_shape_mn: [M_tile, N_tile] = [128, 128] for this kernel
    cta_tile_shape_mn = [mma_tiler_mnk[0], mma_tiler_mnk[1]]
    # cluster_tile_shape_mn: total tile shape per cluster (same as the CTA tile
    # since the cluster is 1x1)
    cluster_tile_shape_mn = tuple(
        x * y for x, y in zip(cta_tile_shape_mn, (1, 1))
    )
    # Compute the total number of cluster tiles needed across all groups.
    # Each group's (m, n) extent is divided into tiles of cluster_tile_shape_mn;
    # this determines the total grid size (bidz dimension) for the kernel launch.
    total_num_clusters = 0
    num_groups = len(problem_sizes)
    for m, n, _, _ in problem_sizes:
        # Number of tiles needed in the M and N dimensions for this group
        num_clusters_mn = tuple(
            (x + y - 1) // y for x, y in zip((m, n), cluster_tile_shape_mn)
        )
        # Multiply M_tiles * N_tiles to get the total tiles for this group
        total_num_clusters += functools.reduce(lambda x, y: x * y, num_clusters_mn)
    # Allocate device memory for the tensormap descriptors.
    # Each cluster needs its own set of tensormaps (one each for A, B, SFA, SFB).
    # Shape: (total_num_clusters, num_tensormaps=4, bytes_per_tensormap/8=16).
    # Tensormaps are hardware descriptors used by TMA for efficient memory transfers.
    tensormap_shape = (
        total_num_clusters,
        num_tensormaps,
        bytes_per_tensormap // 8,
    )
    tensor_of_tensormap = torch.empty(tensormap_shape, dtype=torch.int64, device="cuda")
    # Create CuTe pointers to the metadata tensors passed to the kernel.
    # These let the GPU kernel read the problem sizes and tensor pointers.
    cute_ptr_of_tensor_of_abc_ptrs = make_ptr(
        cutlass.Int64,
        tensor_of_abc_ptrs.data_ptr(),
        cute.AddressSpace.gmem,
        assumed_align=16,
    )
    cute_ptr_of_tensor_of_sfasfb_ptrs = make_ptr(
        cutlass.Int64,
        tensor_of_sfasfb_ptrs.data_ptr(),
        cute.AddressSpace.gmem,
        assumed_align=16,
    )
    cute_ptr_of_tensor_of_problem_sizes = make_ptr(
        cutlass.Int32,
        tensor_of_problem_sizes.data_ptr(),
        cute.AddressSpace.gmem,
        assumed_align=16,
    )
    cute_ptr_of_tensor_of_tensormap = make_ptr(
        cutlass.Int64,
        tensor_of_tensormap.data_ptr(),
        cute.AddressSpace.gmem,
        assumed_align=16,
    )
    # Launch the JIT-compiled GPU kernel with all prepared data.
    # The kernel performs the block-scaled group GEMM
    # C = (A * SFA) @ (B * SFB) for all groups.
    compiled_func(
        cute_ptr_of_tensor_of_problem_sizes,  # Pointer to the problem sizes array
        cute_ptr_of_tensor_of_abc_ptrs,  # Pointer to the ABC tensor pointers array
        cute_ptr_of_tensor_of_sfasfb_ptrs,  # Pointer to the scale factor pointers array
        cute_ptr_of_tensor_of_tensormap,  # Pointer to the tensormap buffer
        total_num_clusters,  # Total number of CTAs to launch
    )
    return [abc_tensors[i][2] for i in range(num_groups)]
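

# Minimal usage sketch (hypothetical shapes; the evaluation framework normally
# constructs `data` with packed FP4 operands and reordered scale factors, so
# this is shown as a commented outline rather than runnable setup code):
#
#   problem_sizes = [(256, 256, 256, 1), (384, 128, 512, 1)]
#   compile_kernel(problem_sizes)   # one-time JIT compilation, outside timing
#   out = custom_kernel(data)       # data = (abc, sfasfb, sfasfb_reordered, sizes)
#   torch.cuda.synchronize()        # wait for the kernel before reading `out`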