Skip to content

Instantly share code, notes, and snippets.

@scturtle
Created April 8, 2024 11:44
Show Gist options
  • Save scturtle/706fb1d77611232a8047ab7a20bc1e38 to your computer and use it in GitHub Desktop.
Save scturtle/706fb1d77611232a8047ab7a20bc1e38 to your computer and use it in GitHub Desktop.
How matmul is tiled in CUDA (simulated in Python with threads and NumPy)
import numpy as np
from threading import Barrier, Thread
from collections import namedtuple
# Three-component index/extent tuple modelled on CUDA's ``dim3``.
# ``x`` is required; ``y`` and ``z`` default to 1.
dim3 = namedtuple("dim3", "x y z", defaults=(1, 1))

# Edge length of the square tiles staged in the per-block scratch buffer.
TILE = 16
def cdiv(a, b):
    """Return ``a / b`` rounded up, for non-negative ``a`` and positive ``b``."""
    # Adding ``b - 1`` before flooring bumps any non-exact quotient up by one.
    padded = a + b - 1
    return padded // b
def matmul_kernel(
    blockIdx: dim3,
    threadIdx: dim3,
    blockDim: dim3,
    sync: Barrier,
    shared,
    a,
    b,
    out,
):
    """Simulate one thread of a tiled ``out = a @ b`` kernel.

    The thread owns output element ``(row, col)`` where
    ``row = blockIdx.y * blockDim.y + threadIdx.y`` and
    ``col = blockIdx.x * blockDim.x + threadIdx.x``. The inner dimension is
    walked in chunks of ``TILE``: in each step every thread of the block
    copies one element of ``a`` and one of ``b`` into the block's scratch
    tiles, then accumulates a partial dot product from those tiles.

    Args:
        blockIdx: Index of this thread's block within the grid.
        threadIdx: Index of this thread within its block.
        blockDim: Block size in threads. The tile indexing assumes
            ``blockDim.x == blockDim.y == TILE``.
        sync: Barrier shared by every thread of the block.
        shared: Per-block scratch array; ``shared[0]`` and ``shared[1]`` are
            each viewed as a ``TILE x TILE`` tile.
        a: Left operand, 2-D with shape ``(h, k)``.
        b: Right operand, 2-D; its second dimension gives the output width.
        out: 2-D output array; only in-range ``(row, col)`` is written.
    """
    n_rows, inner = a.shape
    _, n_cols = b.shape

    local_row, local_col = threadIdx.y, threadIdx.x
    row = blockIdx.y * blockDim.y + local_row
    col = blockIdx.x * blockDim.x + local_col

    tile_a = shared[0].reshape(TILE, TILE)
    tile_b = shared[1].reshape(TILE, TILE)

    acc = 0.0
    for step in range(cdiv(inner, TILE)):
        offset = step * TILE
        a_col = offset + local_col
        b_row = offset + local_row

        # Out-of-range positions are zero-filled so they add nothing below.
        if row < n_rows and a_col < inner:
            tile_a[local_row, local_col] = a[row, a_col]
        else:
            tile_a[local_row, local_col] = 0.0

        if col < n_cols and b_row < inner:
            tile_b[local_row, local_col] = b[b_row, col]
        else:
            tile_b[local_row, local_col] = 0.0

        # Both tiles must be fully populated before any thread reads them.
        sync.wait()

        for j in range(TILE):
            acc += tile_a[local_row, j] * tile_b[j, local_col]

        # Nobody may overwrite the tiles for the next step while others
        # are still reading the current ones.
        sync.wait()

    if row < n_rows and col < n_cols:
        out[row, col] = acc
def launch_kernel(f, blocks: dim3, tpb: dim3, *args):
    """Run kernel ``f`` over a 2-D grid, one Python thread per kernel thread.

    Blocks run one after another. Within a block, every ``(x, y)`` thread
    index gets its own ``Thread``; all threads of the block share a single
    ``Barrier`` sized ``tpb.x * tpb.y`` and a freshly allocated
    ``(2, TILE, TILE)`` float32 scratch array. The ``z`` components of
    ``blocks`` and ``tpb`` are not used.

    Args:
        f: Kernel callable, invoked as
            ``f(blockIdx, threadIdx, blockDim, sync, shared, *args)``.
        blocks: Grid size in blocks.
        tpb: Block size in threads.
        *args: Extra positional arguments forwarded to every kernel call.

    Raises:
        Exception: The first exception raised inside any kernel thread of a
            block is re-raised here once all threads of that block have
            finished; later blocks are not launched.
    """

    def run_thread(block_idx, thread_idx, sync, shared, errors):
        try:
            f(block_idx, thread_idx, tpb, sync, shared, *args)
        except Exception as exc:
            # Record the failure *before* breaking the barrier so the root
            # cause precedes the BrokenBarrierError its siblings will see.
            errors.append(exc)
            # Without this, siblings parked in ``sync.wait()`` would block
            # forever and ``join()`` below would never return.
            sync.abort()

    for iby in range(blocks.y):
        for ibx in range(blocks.x):
            sync = Barrier(tpb.y * tpb.x)
            shared = np.empty((2, TILE, TILE), dtype=np.float32)
            errors = []
            threads = [
                Thread(
                    target=run_thread,
                    args=(dim3(ibx, iby), dim3(itx, ity), sync, shared, errors),
                )
                for ity in range(tpb.y)
                for itx in range(tpb.x)
            ]
            for t in threads:
                t.start()
            for t in threads:
                t.join()
            if errors:
                raise errors[0]
def matmul(a, b):
    """Multiply two 2-D arrays with the simulated tiled kernel.

    Args:
        a: Left operand with shape ``(h, k)``.
        b: Right operand with shape ``(k, w)``.

    Returns:
        A new ``(h, w)`` array with ``a``'s dtype holding ``a @ b``.

    Raises:
        ValueError: If the inner dimensions of ``a`` and ``b`` differ.
    """
    h, k = a.shape
    k_b, w = b.shape
    if k != k_b:
        raise ValueError(
            f"matmul: inner dimensions do not match ({a.shape} @ {b.shape})"
        )
    out = np.empty((h, w), dtype=a.dtype)
    # The kernel derives the column from blockIdx.x and the row from
    # blockIdx.y, so the grid's x extent must cover the width and its
    # y extent the height; otherwise non-square outputs are left partly
    # uncomputed.
    blocks = dim3(cdiv(w, TILE), cdiv(h, TILE))
    tpb = dim3(TILE, TILE)
    launch_kernel(matmul_kernel, blocks, tpb, a, b, out)
    return out
# Demo / self-check: multiply random 20x30 and 30x20 matrices with the
# simulated kernel and print whether the result agrees with NumPy's matmul.
a = np.random.rand(20, 30)
b = np.random.rand(30, 20)
print(np.allclose(np.matmul(a, b), matmul(a, b)))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment