Triton kernel for matmul that unpacks int2-packed weights on the fly
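The script below packs four 2-bit weight values into each uint8 along the K dimension and has the Triton kernel decode them on the fly while accumulating the matmul in int32. A minimal usage sketch (hypothetical shapes and tensor names, assuming a CUDA device and the functions defined in the file below):

x = torch.randint(-127, 127, (16, 4096), device='cuda', dtype=torch.int8)  # int8 activations, shape (M, K)
w = torch.randint(-1, 2, (4096, 4096), device='cuda', dtype=torch.int8)    # ternary weights, shape (K, N)
packed = pack_weights(w, bits=2)  # uint8, shape (K // 4, N); note: pack_weights shifts w in place
y = matmul(x, packed)             # float16 output of shape (M, N); weights are unpacked inside the kernel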
import torch
import triton
import triton.language as tl


def unpack_weights(packed: torch.Tensor, bits: int = 2) -> torch.Tensor:
    values_per_item = 8 // bits
    packed_shape = packed.shape

    if len(packed_shape) == 1:
        original_row_dim = packed_shape[0] * values_per_item
        unpacked_shape = (original_row_dim,)
    else:
        original_row_dim = packed_shape[0] * values_per_item
        unpacked_shape = (original_row_dim, *packed_shape[1:])

    unpacked = torch.zeros(unpacked_shape, device=packed.device, dtype=torch.uint8)

    # Each uint8 holds `values_per_item` 2-bit fields; field i of every byte fills
    # rows [i * packed_shape[0], (i + 1) * packed_shape[0]) of the unpacked tensor.
    for i in range(values_per_item):
        start = i * packed_shape[0]
        end = start + packed_shape[0]
        mask = 3 << (2 * i)
        unpacked[start:end] = (packed & mask) >> (2 * i)

    # Shift the 2-bit codes {0, 1, 2, 3} down to the signed values {-1, 0, 1, 2}.
    unpacked = unpacked.to(torch.float16) - 1
    return unpacked
def pack_weights(intweights: torch.Tensor, bits: int) -> torch.Tensor:
    # Note: this shifts the values in place, so the caller's tensor is modified.
    # Signed values {-1, 0, 1, 2} become the unsigned 2-bit codes {0, 1, 2, 3}.
    intweights += 1
    original_shape = intweights.shape
    values_per_item = 8 // bits
    row_dim = (original_shape[0] + values_per_item - 1) // values_per_item

    if len(original_shape) == 1:
        packed_tensor_shape = (row_dim,)
    else:
        packed_tensor_shape = (row_dim, *original_shape[1:])

    packed = torch.zeros(packed_tensor_shape, device=intweights.device, dtype=torch.uint8)
    unpacked = intweights.to(torch.uint8)

    def lshift(t: torch.Tensor, bits: int):
        return t << bits

    # Pack `values_per_item` consecutive row blocks into the 2-bit fields of each byte.
    it = min(values_per_item, (original_shape[0] // row_dim) + 1)
    for i in range(it):
        start = i * row_dim
        end = min(start + row_dim, original_shape[0])
        packed[: (end - start)] |= lshift(unpacked[start:end], bits * i)

    return packed
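# Round-trip example for the two helpers above (illustrative values, runs on CPU as well):
#   w = torch.tensor([1, -1, 0, 2], dtype=torch.int8)
#   p = pack_weights(w.clone(), bits=2)   # tensor([210], dtype=torch.uint8) == 0b11_01_00_10: fields 2, 0, 1, 3
#   unpack_weights(p, bits=2)             # tensor([ 1., -1.,  0.,  2.], dtype=torch.float16)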
def get_cuda_autotune_config():
    return [
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8},
                      num_stages=3, num_warps=8),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8},
                      num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8},
                      num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8},
                      num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8},
                      num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8},
                      num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8},
                      num_stages=3, num_warps=8),
        triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8},
                      num_stages=3, num_warps=8),
        triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8},
                      num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8},
                      num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8},
                      num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8},
                      num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8},
                      num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8},
                      num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4},
                      num_stages=4, num_warps=4),
    ]
@triton.autotune(
    configs=get_cuda_autotune_config(),
    key=['M', 'N', 'K'],
)
@triton.jit
def matmul_kernel(
        a_ptr, b_ptr, c_ptr,
        M, N, K,
        stride_am, stride_ak,
        stride_bk, stride_bn,
        stride_cm, stride_cn,
        BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
        GROUP_SIZE_M: tl.constexpr,
):
    # Map the 1-D program id to a (pid_m, pid_n) tile using grouped ordering for L2 reuse.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.int32)

    # Outer loop over the 4 two-bit planes packed into each uint8 of B.
    for i in range(4):
        # B has only K // 4 rows, so its pointers are rewound for every bit plane,
        # while the A pointers keep advancing through the full K dimension.
        b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
        for j in range(0, tl.cdiv(K // 4, BLOCK_SIZE_K)):
            k = i * tl.cdiv(K // 4, BLOCK_SIZE_K) + j
            # BLOCK_SIZE_K must be a divisor of K // 4.
            a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0)
            b_uint8 = tl.load(b_ptrs, mask=offs_k[:, None] < K // 4 - j * BLOCK_SIZE_K, other=0)
            mask = 3 << (2 * i)
            b = (b_uint8 & mask) >> (2 * i)
            # We accumulate along the K dimension; subtracting 1 restores the signed values.
            tensor_full = tl.full((1,), 1, dtype=tl.int8)
            accumulator += tl.dot(a, (b.to(tl.int8) - tensor_full), out_dtype=tl.int32)
            # Advance the pointers to the next K block.
            a_ptrs += BLOCK_SIZE_K * stride_ak
            b_ptrs += BLOCK_SIZE_K * stride_bk

    c = accumulator

    # Write back the block of the output matrix C with masks.
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)
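# Note: the kernel accumulates in int32, but the wrapper below allocates the output
# buffer as float16, so the store casts the integer results to half precision on write.
# For large K this cast can round the exact integer dot products, which is presumably
# why the final check at the bottom uses loose tolerances (atol=rtol=1e-2).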
def matmul(a, b):
    # `b` is the packed weight matrix: 4 two-bit values per uint8 along K, i.e. shape (K // 4, N).
    assert a.shape[1] == b.shape[0] * 4, "Incompatible dimensions, the weight matrix needs to be packed"
    assert a.is_contiguous(), "Matrix A must be contiguous"
    M, K = a.shape
    _, N = b.shape
    c = torch.empty((M, N), device=a.device, dtype=torch.float16)
    # 1-D launch grid: one program per (BLOCK_SIZE_M, BLOCK_SIZE_N) tile of C.
    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
    matmul_kernel[grid](
        a, b, c,
        M, N, K,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
    )
    return c
size = 2048
ht = torch.randint(-127, 127, (13, 1, size * 4), device='cuda', dtype=torch.int8)
u = torch.randint(0, 255, (size * 4, size), device='cuda', dtype=torch.uint8)

B, M, N = ht.size()
triton_output = matmul(ht.view(B * M, N), u.T.contiguous()).view(B, M, -1)

assert (pack_weights(unpack_weights(u.T), 2) == u.T).all()

unpacked = unpack_weights(u.T, bits=2).T
torch_output = torch.matmul(ht.half(), unpacked.T.contiguous())

print("triton = ", triton_output)
print("torch = ", torch_output)

if torch.allclose(triton_output, torch_output, atol=1e-2, rtol=1e-2):
    print("Results match")
else:
    print("Results differ")