Skip to content

Instantly share code, notes, and snippets.

Forked from MekkCyber/
Created February 9, 2025 09:29
Show Gist options
  • Save basavyr/e84ee1b2d57ec80525cf93710c6b91a2 to your computer and use it in GitHub Desktop.
Save basavyr/e84ee1b2d57ec80525cf93710c6b91a2 to your computer and use it in GitHub Desktop.
Triton kernel for matrix multiplication that unpacks 2-bit (int2) weights on the fly
import torch
import triton
import triton.language as tl
def unpack_weights(packed: torch.Tensor, bits: int = 2) -> torch.Tensor:
    """Unpack weights packed by ``pack_weights`` back to one value per element.

    Packing is along dim 0: packed row ``r`` stores, in its ``bits``-wide field
    ``i`` (lowest bits first), the code for original row ``i * packed.shape[0] + r``.
    Each code is decoded as ``code - 1``, undoing the ``+1`` offset applied when
    packing (so 2-bit codes 0..3 become -1..2).

    Args:
        packed: packed tensor (expected uint8), 1-D or N-D; only dim 0 is packed.
        bits: width of each packed field; expected to divide 8.

    Returns:
        float16 tensor of shape ``(packed.shape[0] * (8 // bits), *packed.shape[1:])``
        on ``packed.device``.
    """
    values_per_item = 8 // bits
    rows = packed.shape[0]
    # ``shape[1:]`` is empty for 1-D input, so a single expression handles every rank.
    unpacked_shape = (rows * values_per_item, *packed.shape[1:])
    unpacked = torch.zeros(unpacked_shape, device=packed.device, dtype=torch.uint8)

    field_mask = (1 << bits) - 1
    for i in range(values_per_item):
        shift = bits * i
        unpacked[i * rows:(i + 1) * rows] = (packed & (field_mask << shift)) >> shift

    # float16 so the result can be multiplied directly against half-precision activations.
    return unpacked.to(torch.float16) - 1
def pack_weights(intweights: torch.Tensor, bits: int) -> torch.Tensor:
    """Pack small integer weights into uint8, ``8 // bits`` values per byte.

    Every value is offset by ``+1`` (so ternary -1/0/1 become codes 0/1/2) and
    packing happens along dim 0. The rows are split into ``8 // bits`` consecutive
    chunks of ``ceil(rows / (8 // bits))`` rows, and chunk ``i`` is OR-ed into
    the output at bit offset ``bits * i``. The input tensor is not modified.

    Args:
        intweights: integer-valued tensor, 1-D or N-D. After the ``+1`` offset
            every value must fit in ``bits`` bits, otherwise it bleeds into the
            neighbouring field.
        bits: width of each packed field; expected to divide 8.

    Returns:
        uint8 tensor of shape ``(ceil(rows / (8 // bits)), *intweights.shape[1:])``
        on ``intweights.device``.
    """
    original_shape = intweights.shape
    values_per_item = 8 // bits
    num_rows = original_shape[0]
    row_dim = (num_rows + values_per_item - 1) // values_per_item

    packed = torch.zeros((row_dim, *original_shape[1:]), device=intweights.device, dtype=torch.uint8)
    # Offset out of place so the caller's tensor is left untouched.
    unpacked = (intweights + 1).to(torch.uint8)

    for i in range(values_per_item):
        start = i * row_dim
        if start >= num_rows:
            # All remaining chunks are empty (also covers empty input, where row_dim == 0).
            break
        end = min(start + row_dim, num_rows)
        packed[: end - start] |= unpacked[start:end] << (bits * i)
    return packed
def get_cuda_autotune_config():
    """Return the candidate launch configurations the autotuner tries for ``matmul_kernel``.

    Each entry fixes the tile sizes (BLOCK_SIZE_M/N/K), the tile-row grouping
    (GROUP_SIZE_M), the software-pipelining depth (``num_stages``) and the
    number of warps per program.

    NOTE(review): the ``num_warps`` arguments were lost from this copy of the
    file. The values below follow Triton's matrix-multiplication tutorial, which
    these tile shapes mirror: 8 warps for the 128x256 / 256x128 tiles listed
    with 3 stages, 4 elsewhere. Re-tune if the original values differ.
    """
    return [
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3,
                      num_warps=8),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3,
                      num_warps=8),
        triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3,
                      num_warps=8),
        triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4}, num_stages=4,
                      num_warps=4),
    ]
@triton.autotune(
    configs=get_cuda_autotune_config(),
    key=['M', 'N', 'K'],
)
@triton.jit
def matmul_kernel(
        a_ptr, b_ptr, c_ptr,
        M, N, K,
        stride_am, stride_ak,
        stride_bk, stride_bn,
        stride_cm, stride_cn,
        BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
        GROUP_SIZE_M: tl.constexpr,
):
    """Compute C = A @ W, decoding the weights W from 2-bit codes packed in B.

    A is (M, K). B is (K // 4, N): packed row ``r`` holds, in bit field ``i``
    (bits ``2*i`` and ``2*i + 1``), the code for weight row ``i * (K // 4) + r``,
    and the weight value is ``code - 1`` (the layout written by ``pack_weights``).
    Products are accumulated in int32; the accumulator is stored into C, masked
    to the valid M x N region.

    Each program computes one BLOCK_SIZE_M x BLOCK_SIZE_N tile of C. Program ids
    are remapped in groups of GROUP_SIZE_M tile-rows.
    """
    # Map the 1-D program id to a (pid_m, pid_n) output tile, grouped along M.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # Out-of-range rows/columns wrap around; the final store is masked instead.
    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)

    packed_k = K // 4
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.int32)
    one = tl.full((1,), 1, dtype=tl.int8)
    for i in range(4):
        # Bit field i of B holds weight rows [i * packed_k, (i + 1) * packed_k), so
        # A is re-based at column i * packed_k for every section rather than being
        # advanced continuously. That keeps A and B aligned even when BLOCK_SIZE_K
        # does not divide K // 4.
        a_ptrs = a_ptr + (offs_am[:, None] * stride_am + (i * packed_k + offs_k[None, :]) * stride_ak)
        b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
        mask = 3 << (2 * i)
        for j in range(0, tl.cdiv(packed_k, BLOCK_SIZE_K)):
            k_remaining = packed_k - j * BLOCK_SIZE_K
            # Both operands use the same bound. Padded B codes decode to -1, but the
            # matching A entries are padded with 0, so they contribute nothing.
            a = tl.load(a_ptrs, mask=offs_k[None, :] < k_remaining, other=0)
            b_uint8 = tl.load(b_ptrs, mask=offs_k[:, None] < k_remaining, other=0)
            b = ((b_uint8 & mask) >> (2 * i)).to(tl.int8) - one
            accumulator += tl.dot(a, b, out_dtype=tl.int32)
            # Advance both operands to the next K block within this section.
            a_ptrs += BLOCK_SIZE_K * stride_ak
            b_ptrs += BLOCK_SIZE_K * stride_bk
    c = accumulator

    # Write back the block of the output matrix C with masks.
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)
def matmul(a, b):
    """Multiply ``a`` by the 2-bit weights packed in ``b`` using ``matmul_kernel``.

    Args:
        a: (M, K) activation matrix; must be contiguous. It is fed to ``tl.dot``
            against int8-decoded weights with an int32 accumulator, so presumably
            int8. TODO confirm the supported dtypes.
        b: (K // 4, N) weights packed along dim 0 by ``pack_weights`` (four 2-bit
            codes per byte). Arbitrary strides are accepted.

    Returns:
        (M, N) float16 tensor on ``a.device``. The kernel's integer results are
        stored into this float16 buffer, so magnitudes beyond float16's range
        (about 65504) cannot be represented exactly.
    """
    assert a.shape[1] == b.shape[0] * 4, "Incompatible dimensions, the weight matrix need to be packed"
    assert a.is_contiguous(), "Matrix A must be contiguous"
    M, K = a.shape
    _, N = b.shape
    c = torch.empty((M, N), device=a.device, dtype=torch.float16)

    def grid(meta):
        # One program per output tile; tile sizes come from the autotuned config.
        return (triton.cdiv(M, meta['BLOCK_SIZE_M']) * triton.cdiv(N, meta['BLOCK_SIZE_N']),)

    matmul_kernel[grid](
        a, b, c,
        M, N, K,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
    )
    return c
if __name__ == "__main__":
    # Self-test: compare the Triton kernel against a dense PyTorch reference.
    # Guarded so importing this module does not launch CUDA work.
    size = 2048
    # torch.randint's upper bound is exclusive; cover symmetric int8 [-127, 127]
    # and every packed byte value 0..255.
    ht = torch.randint(-127, 128, (13, 1, size * 4), device='cuda', dtype=torch.int8)
    u = torch.randint(0, 256, (size * 4, size), device='cuda', dtype=torch.uint8)
    B, M, K = ht.size()
    # u.T is (size, 4 * size) == (K // 4, N): the packed weight matrix.
    triton_output = matmul(ht.view(B * M, K), u.T.contiguous()).view(B, M, -1)
    assert (pack_weights(unpack_weights(u.T), 2) == u.T).all()
    unpacked = unpack_weights(u.T, bits=2)  # (K, N) decoded weights
    torch_output = torch.matmul(ht.half(), unpacked)
    print("triton = ", triton_output)
    print("torch = ", torch_output)
    if torch.allclose(triton_output, torch_output, atol=1e-2, rtol=1e-2):
        print("Results match")
    else:
        print("Results differ")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment