Created
September 19, 2023 23:56
-
-
Save HDCharles/44952fc614a75ad083f5054d50ef5341 to your computer and use it in GitHub Desktop.
Note: despite the kernel's name, this version does not use Triton block pointers.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
@triton.jit
def matmul_kernel_with_block_pointers(
    # Pointers to matrices
    a_ptr, b_ptr, c_ptr, s1_ptr, s2_ptr,
    # Matrix dimensions
    M, N, K,
    # The stride variables represent how much to increase the ptr by when moving by 1
    # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
    # by to get the element one row down (A has M rows).
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    stride_s1m, stride_s1n,
    stride_s2m, stride_s2n,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr
):
    """Compute the scaled matmul C = (A @ B) * S1 * S2, stored as fp16.

    A has shape (M, K), B has shape (K, N) and C has shape (M, N).
    S1 and S2 are read at every (m, n) position of C through their own
    (m, n) strides. They are multiplied elementwise into the fp32
    accumulator before the result is cast to fp16 and stored.

    Each program instance computes one [BLOCK_SIZE_M, BLOCK_SIZE_N] tile of C.

    Note: despite the name, this kernel uses plain pointer arithmetic, not
    Triton block pointers (``tl.make_block_ptr``).

    NOTE(review): S1/S2 are presumably scale tensors that the caller may
    broadcast by passing a zero stride (e.g. ``stride_s1n == 0`` for a
    per-row scale). Confirm against the launch site.
    """
    # -----------------------------------------------------------
    # Map program id `pid` to the (pid_m, pid_n) tile of C it computes.
    # Programs are grouped GROUP_SIZE_M tile-rows at a time: within a group,
    # pid_m varies fastest, so consecutive programs share pid_n and reuse the
    # same columns of B (better L2 data reuse).
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    # The last group may contain fewer than GROUP_SIZE_M tile-rows.
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # -----------------------------------------------------------
    # Pointers to the first K-tile of A ([BLOCK_SIZE_M, BLOCK_SIZE_K]) and
    # B ([BLOCK_SIZE_K, BLOCK_SIZE_N]).
    # Row/column offsets are wrapped with `% M` / `% N`. Edge tiles therefore
    # read valid (duplicated) rows/columns instead of going out of bounds.
    # The extra results those lanes produce are masked out on store.
    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    # -----------------------------------------------------------
    # Accumulate the [BLOCK_SIZE_M, BLOCK_SIZE_N] tile of A @ B in fp32
    # for higher accuracy.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Zero-fill the tail of the last K-tile so it contributes nothing
        # to the dot product.
        k_remaining = K - k * BLOCK_SIZE_K
        a = tl.load(a_ptrs, mask=offs_k[None, :] < k_remaining, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < k_remaining, other=0.0)
        accumulator += tl.dot(a, b)
        # Advance the pointers to the next K-tile.
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    # -----------------------------------------------------------
    # Epilogue: scale elementwise by S1 and S2, cast to fp16, and store.
    # Unlike offs_am/offs_bn above, these offsets are NOT wrapped. Every
    # access at or past M or N must be masked, including the S1/S2 loads;
    # otherwise edge tiles read out of bounds.
    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    s1_ptrs = s1_ptr + offs_m[:, None] * stride_s1m + offs_n[None, :] * stride_s1n
    s2_ptrs = s2_ptr + offs_m[:, None] * stride_s2m + offs_n[None, :] * stride_s2n
    # Masked-off lanes are never stored, so the fill value does not matter.
    s1 = tl.load(s1_ptrs, mask=c_mask, other=0.0)
    s2 = tl.load(s2_ptrs, mask=c_mask, other=0.0)
    c = accumulator * s1 * s2
    c = c.to(tl.float16)
    c_ptrs = c_ptr + stride_cm * offs_m[:, None] + stride_cn * offs_n[None, :]
    tl.store(c_ptrs, c, mask=c_mask)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment