# run_matmul_gtx1060.py
import torch
import triton
import triton.language as tl
import torch.nn.functional as F
@triton.jit
def matmul_kernel(
    a_ptr, b_ptr, c_ptr,
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    M: tl.constexpr, N: tl.constexpr, K: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
):
    # Map the 1D program id to a (block-row, block-column) tile of C so each
    # program computes a distinct BLOCK_SIZE_M x BLOCK_SIZE_N tile.
    pid = tl.program_id(axis=0)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    pid_m = pid // num_pid_n
    pid_n = pid % num_pid_n
    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    # Pointers to the first K-tile of A and B for this program.
    a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    # Walk the K dimension one BLOCK_SIZE_K tile at a time. This assumes
    # M, N and K are multiples of the block sizes, so no masking is needed.
    for k in range(0, K, BLOCK_SIZE_K):
        a = tl.load(a_ptrs)
        b = tl.load(b_ptrs)
        accumulator += tl.dot(a, b)
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk
    # Write the accumulated tile back to C.
    c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
    tl.store(c_ptrs, accumulator)
# 1024 x 1024 fp32 inputs; the sizes are multiples of the 16x16x16 block shape used below.
X = torch.normal(0, 1, size=(1024, 1024), device='cuda')
Y = torch.normal(0, 1, size=(1024, 1024), device='cuda')
def matmul(a, b, activation=""):
    # Check constraints.
    assert a.shape[1] == b.shape[0], f"Incompatible dimensions {a.shape[1]} != {b.shape[0]}"
    assert a.is_contiguous(), "Matrix A must be contiguous"
    assert b.is_contiguous(), "Matrix B must be contiguous"
    M, K = a.shape
    K, N = b.shape
    # Allocate the output.
    c = torch.empty((M, N), device=a.device, dtype=a.dtype)
    # 1D launch grid where each program computes one BLOCK_SIZE_M x BLOCK_SIZE_N tile of C.
    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
    t = 16  # block size for M, N and K; tl.dot requires at least 16 in each dimension
    compiled = matmul_kernel[grid](
        a, b, c,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
        M, N, K,
        t, t, t,
    )
    # Dump the intermediate representations produced by the Triton compiler.
    print("TTIR", compiled.asm['ttir'])
    print("TTGIR", compiled.asm['ttgir'])
    print("LLIR", compiled.asm['llir'])
    return c
Z = matmul(X, Y)
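# Optional sanity check (a minimal sketch): compare against torch.matmul on the
# same inputs. A loose tolerance is used because the fp32 accumulation order
# differs between the two implementations.
ref = torch.matmul(X, Y)
print("max abs error vs torch.matmul:", (Z - ref).abs().max().item())
assert torch.allclose(Z, ref, atol=1e-2, rtol=1e-2)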
# Each JIT function keeps a per-device compilation cache; grab the compiled
# kernel for device 0 and dump its PTX to a file.
print(dir(matmul_kernel.cache))
with open("matmul_kernel.ptx", "w") as f:
    print(list(matmul_kernel.cache[0].values())[0].asm['ptx'], file=f)
#
# @triton.testing.perf_report(
#     triton.testing.Benchmark(
#         x_names=['size'],  # Argument names to use as an x-axis for the plot.
#         x_vals=[2**i for i in range(4, 10, 1)],  # Different possible values for `x_names`.
#         x_log=True,  # x axis is logarithmic.
#         line_arg='provider',  # Argument name whose value corresponds to a different line in the plot.
#         line_vals=['triton', 'torch'],  # Possible values for `line_arg`.
#         line_names=['Triton', 'Torch'],  # Label name for the lines.
#         styles=[('blue', '-'), ('green', '-')],  # Line styles.
#         ylabel='TFLOPS',  # Label name for the y-axis.
#         plot_name='mat-mul-performance',  # Name for the plot; also used as a file name for saving it.
#         args={},  # Values for function arguments not in `x_names` and `y_name`.
#     ))
# def benchmark(size, provider):
#     x = torch.rand((size, size), device='cuda', dtype=torch.float32)
#     y = torch.rand((size, size), device='cuda', dtype=torch.float32)
#     quantiles = [0.5, 0.2, 0.8]
#     if provider == 'torch':
#         ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(x, y), quantiles=quantiles)
#     if provider == 'triton':
#         ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(x, y), quantiles=quantiles)
#     # A square matmul performs 2 * size^3 floating-point operations.
#     tflops = lambda ms: 2 * size ** 3 / ms * 1e-9
#     return tflops(ms), tflops(max_ms), tflops(min_ms)
#
# benchmark.run(print_data=True, show_plots=True)
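
# A minimal timing sketch that works without the perf_report machinery above:
# it times the 1024x1024 matmul for each backend with triton.testing.do_bench
# and reports TFLOPS. Note that matmul() re-prints the compiler IR on every
# call, so comment out those prints first for clean output; throughput on a
# GTX 1060 will be modest since the card has no tensor cores.
ms_triton = triton.testing.do_bench(lambda: matmul(X, Y))
ms_torch = triton.testing.do_bench(lambda: torch.matmul(X, Y))
flops = 2 * 1024 ** 3  # 2 * M * N * K with M = N = K = 1024
print(f"triton: {ms_triton:.3f} ms, {flops / ms_triton * 1e-9:.2f} TFLOPS")
print(f"torch : {ms_torch:.3f} ms, {flops / ms_torch * 1e-9:.2f} TFLOPS")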