Skip to content

Instantly share code, notes, and snippets.

@xmfan
Created May 3, 2024 18:58
Show Gist options
  • Save xmfan/fd769f662107bc7b6eb8088595887f33 to your computer and use it in GitHub Desktop.
Save xmfan/fd769f662107bc7b6eb8088595887f33 to your computer and use it in GitHub Desktop.
fused rms
import math
import torch
import triton
import triton.language as tl
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=warps) for warps in (1, 2, 4, 8, 16, 32)
    ],
    key=["N"],
)
@triton.jit
def _rms_norm_fwd_kernel(
    X,
    stride_x,
    Y,
    stride_y,
    W,
    Rstd,
    eps,
    M,  # num rows
    N,  # num cols
    block_N: tl.constexpr,
):
    """RMSNorm forward for a single row.

    The program index selects the row. The first ``N`` columns of that row
    are loaded from ``X`` (promoted to float32), scaled by
    ``1 / sqrt(sum(x * x) / N + eps)`` and by the weight ``W``, and written
    to ``Y``. The per-row scale factor is also written to ``Rstd[row]``.
    Columns are addressed with unit stride; ``stride_x`` / ``stride_y`` are
    the row strides. ``M`` is accepted but not read here.
    """
    row_idx = tl.program_id(0)
    col_offsets = tl.arange(0, block_N)
    # block_N may exceed N; lanes past the row end are masked off.
    in_bounds = col_offsets < N

    x_ptrs = X + row_idx * stride_x + col_offsets
    y_ptrs = Y + row_idx * stride_y + col_offsets

    x = tl.load(x_ptrs, mask=in_bounds, other=0.0).to(tl.float32)
    w = tl.load(W + col_offsets, mask=in_bounds, other=0.0).to(tl.float32)

    # Mean of squares over the valid columns (no mean subtraction).
    x_valid = tl.where(in_bounds, x, 0.0)
    mean_sq = tl.sum(x_valid * x_valid, axis=0) / N
    rstd = 1 / tl.sqrt(mean_sq + eps)

    # Keep the reciprocal RMS so the backward pass can reuse it.
    tl.store(Rstd + row_idx, rstd)

    # Normalize, apply the weight, and write the row out.
    tl.store(y_ptrs, x * rstd * w, mask=in_bounds)
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=warps) for warps in (1, 2, 4, 8, 16, 32)
    ],
    key=["N"],
)
@triton.jit
def _rms_norm_bwd_kernel_sm(
    X,
    stride_x,
    W,
    DY,
    stride_dy,
    DX,
    stride_dx,
    Rstd,
    DW,
    eps,
    M,  # num rows
    N,  # num cols
    rows_per_program,
    block_N: tl.constexpr,
):
    """RMSNorm backward over a contiguous range of rows.

    Program ``p`` handles rows ``[p * rows_per_program,
    min((p + 1) * rows_per_program, M))``. For each row it writes the input
    gradient to ``DX`` and adds ``dy * x_hat`` into a float32 accumulator;
    the accumulator is stored to row ``p`` of ``DW`` (row pitch ``N``) once
    the loop finishes. ``eps`` is accepted but not read here; the saved
    ``Rstd`` values are used instead.
    """
    pid = tl.program_id(0)
    first_row = pid * rows_per_program
    col_offsets = tl.arange(0, block_N)
    in_bounds = col_offsets < N

    # The weight is shared by every row, so load it once.
    w = tl.load(W + col_offsets, mask=in_bounds, other=0.0).to(tl.float32)

    # Partial weight gradient for this program's rows.
    dw_acc = tl.zeros((block_N,), dtype=tl.float32)

    last_row = min(first_row + rows_per_program, M)
    for row in range(first_row, last_row):
        x_ptrs = X + row * stride_x + col_offsets
        dy_ptrs = DY + row * stride_dy + col_offsets
        x = tl.load(x_ptrs, mask=in_bounds, other=0.0).to(tl.float32)
        dy = tl.load(dy_ptrs, mask=in_bounds, other=0.0).to(tl.float32)
        rstd = tl.load(Rstd + row)

        # Rebuild the normalized input from the saved reciprocal RMS.
        x_hat = x * rstd
        wdy = w * dy
        dw_acc += dy * x_hat

        # Projection of the weighted upstream gradient onto x_hat.
        c1 = tl.sum(x_hat * wdy, axis=0) / N
        dx = (wdy - x_hat * c1) * rstd

        tl.store(DX + row * stride_dx + col_offsets, dx, mask=in_bounds)

    # One partial-sum row per program.
    tl.store(DW + pid * N + col_offsets, dw_acc, mask=in_bounds)
def _rms_norm_block_size(N, element_size):
    """Return the power-of-two column block used by the RMSNorm kernels.

    The block is ``next_power_of_2(N)`` capped at ``65536 // element_size``
    elements. Each kernel program handles a whole row in one block, so a
    row wider than the cap cannot be processed.

    Raises:
        ValueError: if ``N`` exceeds the capped block size.
    """
    max_size = 65536 // element_size
    block_N = min(max_size, triton.next_power_of_2(N))
    if N > block_N:
        raise ValueError(f"N {N} must be <= {block_N=}")
    return block_N


class TritonFusedRMSNorm(torch.autograd.Function):
    """Autograd wrapper around the Triton RMSNorm forward/backward kernels.

    ``forward`` normalizes over the last dimension of ``x`` and scales by
    ``weight``; ``backward`` returns gradients for ``x`` and ``weight`` and
    ``None`` for ``eps``.
    """

    @staticmethod
    def forward(ctx, x, weight, eps):
        """Compute RMSNorm of ``x`` over its last dimension.

        Args:
            ctx: autograd context; receives ``eps``, the original input
                shape, and the tensors needed by ``backward``.
            x: input tensor; all leading dimensions are flattened to rows.
            weight: per-column scale with ``x.shape[-1]`` elements.
            eps: value added to the mean of squares before the square root.

        Returns:
            Tensor with the same shape as ``x``.

        Raises:
            ValueError: if the last dimension is too wide for one block.
        """
        x_shape_start = x.shape
        # reshape (not view): a view raises when the strides cannot be
        # flattened without a copy, e.g. for transposed or sliced inputs.
        x = x.reshape(-1, x.shape[-1])
        # The kernels address columns with unit stride.
        if x.stride(-1) != 1:
            x = x.contiguous()
        if weight.stride(-1) != 1:
            weight = weight.contiguous()
        M, N = x.shape
        # Allocate contiguously so the output also has unit column stride,
        # independent of how x happens to be laid out.
        y = torch.empty_like(x, memory_format=torch.contiguous_format)
        rstd = torch.empty((M,), dtype=torch.float32, device=x.device)
        block_N = _rms_norm_block_size(N, x.element_size())
        # With zero rows there is nothing to normalize, so skip the launch
        # rather than request an empty grid.
        if M > 0:
            grid = lambda meta: (M,)  # one program per row
            _rms_norm_fwd_kernel[grid](
                x,
                x.stride(0),
                y,
                y.stride(0),
                weight,
                rstd,
                eps,
                M,
                N,
                block_N,
            )
        ctx.eps = eps
        ctx.save_for_backward(x, weight, rstd)
        ctx.x_shape_start = x_shape_start
        return y.reshape(x_shape_start)

    @staticmethod
    def backward(ctx, dy):
        """Compute gradients with respect to ``x`` and ``weight``.

        Args:
            ctx: autograd context populated by ``forward``.
            dy: gradient of the loss with respect to the forward output.

        Returns:
            ``(dx, dw, None)``; ``dx`` has the original input shape, ``dw``
            has the weight's dtype, and ``eps`` receives no gradient.

        Raises:
            ValueError: if the last dimension is too wide for one block.
        """
        x, weight, rstd = ctx.saved_tensors
        eps = ctx.eps
        x_shape_start = ctx.x_shape_start
        # reshape (not view): upstream gradients may arrive with strides
        # that cannot be flattened without a copy.
        dy = dy.reshape(-1, dy.shape[-1])
        if dy.stride(-1) != 1:
            dy = dy.contiguous()
        M, N = dy.shape
        # Contiguous so the kernel's unit column stride holds and the final
        # reshape back to the input shape needs no copy.
        dx = torch.empty_like(x, memory_format=torch.contiguous_format)
        sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
        # One float32 partial weight-gradient row per program; every row is
        # written by the kernel, so no zero-initialization is needed.
        _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)
        block_N = _rms_norm_block_size(N, x.element_size())
        # Integer ceiling division: rows are split evenly over the programs.
        rows_per_sm = (M + sm_count - 1) // sm_count
        grid = lambda meta: (sm_count,)  # one program per multiprocessor
        _rms_norm_bwd_kernel_sm[grid](
            x,
            x.stride(0),
            weight,
            dy,
            dy.stride(0),
            dx,
            dx.stride(0),
            rstd,
            _dw,
            eps,
            M,
            N,
            rows_per_sm,
            block_N,
        )
        # Reduce the per-program partial sums into the final weight gradient.
        dw = _dw.sum(0).to(weight.dtype)
        dx = dx.reshape(x_shape_start)
        return dx, dw, None
if __name__ == "__main__":
    # Driver: run the fused RMSNorm under torch.compile on a CUDA tensor,
    # with an in-place update of the input inside the compiled function,
    # then backpropagate a scalar loss.

    def fused_rms_norm_fn(x, weight, eps=1e-6):
        """Apply ``TritonFusedRMSNorm`` with a default ``eps`` of 1e-6."""
        return TritonFusedRMSNorm.apply(x, weight, eps)

    def fn(x, weight):
        """Increment ``x`` in place, offset the weight by one, normalize."""
        x += 1
        shifted_weight = weight + 1
        return fused_rms_norm_fn(x, shifted_weight)

    inp = torch.randn(200, device="cuda")
    param = torch.nn.Parameter(torch.ones(200, device="cuda"))

    compiled_fn = torch.compile(fn)
    result = compiled_fn(inp, param)
    result.sum().backward()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment