fused rms

import math

import torch
import triton
import triton.language as tl


@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
        triton.Config({}, num_warps=16),
        triton.Config({}, num_warps=32),
    ],
    key=["N"],
)
@triton.jit
def _rms_norm_fwd_kernel(
    X,
    stride_x,
    Y,
    stride_y,
    W,
    Rstd,
    eps,
    M,  # num rows
    N,  # num cols
    block_N: tl.constexpr,
):
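    # RMS norm, one program per row: y = x / sqrt(mean(x^2) + eps) * w.
    # block_N is a power of two covering the full feature dimension N.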
    row = tl.program_id(0)
    cols = tl.arange(0, block_N)

    # Load input data and weights
    mask = cols < N
    x = tl.load(X + row * stride_x + cols, mask=mask, other=0.0).to(tl.float32)
    w = tl.load(W + cols, mask=mask, other=0.0).to(tl.float32)

    # Compute the mean of squares (RMS norm subtracts no mean) and its
    # reciprocal square root
    xbar = tl.where(cols < N, x, 0.0)
    var = tl.sum(xbar * xbar, axis=0) / N
    rstd = 1 / tl.sqrt(var + eps)

    # Store the reciprocal standard deviation for reuse in the backward pass
    tl.store(Rstd + row, rstd)

    # Normalize and apply the elementwise weight
    x_hat = x * rstd
    y = x_hat * w

    # Write output
    tl.store(Y + row * stride_y + cols, y, mask=mask)


@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
        triton.Config({}, num_warps=16),
        triton.Config({}, num_warps=32),
    ],
    key=["N"],
)
@triton.jit
def _rms_norm_bwd_kernel_sm(
    X,
    stride_x,
    W,
    DY,
    stride_dy,
    DX,
    stride_dx,
    Rstd,
    DW,
    eps,
    M,  # num rows
    N,  # num cols
    rows_per_program,
    block_N: tl.constexpr,
):
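    # The grid is sized to the SM count: each program walks a contiguous chunk
    # of rows, writes dx per row, and accumulates one partial dw row that the
    # host sums across programs afterwards (a split reduction over rows).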
    row_block_id = tl.program_id(0)
    row_start = row_block_id * rows_per_program
    cols = tl.arange(0, block_N)
    mask = cols < N

    # Load weights
    w = tl.load(W + cols, mask=mask, other=0.0).to(tl.float32)

    # Accumulate gradients for weights
    dw = tl.zeros((block_N,), dtype=tl.float32)
    row_end = min(row_start + rows_per_program, M)
    for row in range(row_start, row_end):
        # Load input, output gradient, and reciprocal standard deviation
        x = tl.load(X + row * stride_x + cols, mask=mask, other=0.0).to(tl.float32)
        dy = tl.load(DY + row * stride_dy + cols, mask=mask, other=0.0).to(tl.float32)
        rstd = tl.load(Rstd + row)

        # Compute normalized input and gradients
        x_hat = x * rstd
        wdy = w * dy
        dw += dy * x_hat
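        # For y = x_hat * w with x_hat = x * rstd, differentiating rstd through
        # mean(x^2) gives dx = rstd * (w*dy - x_hat * mean(x_hat * w*dy))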
        c1 = tl.sum(x_hat * wdy, axis=0) / N
        dx = (wdy - x_hat * c1) * rstd

        # Store input gradient
        tl.store(DX + row * stride_dx + cols, dx, mask=mask)

    # Store this program's partial weight gradient; the host reduces over programs
    tl.store(DW + row_block_id * N + cols, dw, mask=mask)


class TritonFusedRMSNorm(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, weight, eps):
        x_shape_start = x.shape

        # Flatten input
        x = x.view(-1, x.shape[-1])
        if x.stride(-1) != 1:
            x = x.contiguous()
        if weight.stride(-1) != 1:
            weight = weight.contiguous()

        M, N = x.shape
        y = torch.empty_like(x)
        rstd = torch.empty((M,), dtype=torch.float32, device=x.device)
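        # Cap the block at 64 KB of data (presumably so one row's tile stays
        # within per-block resource limits); wider rows are rejected below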
        max_size = 65536 // x.element_size()
        block_N = min(max_size, triton.next_power_of_2(N))
        if N > block_N:
            raise ValueError(f"feature dim N={N} must fit in a single block (block_N={block_N})")

        grid = lambda meta: (M,)
        _rms_norm_fwd_kernel[grid](
            x,
            x.stride(0),
            y,
            y.stride(0),
            weight,
            rstd,
            eps,
            M,
            N,
            block_N,
        )

        ctx.eps = eps
        ctx.save_for_backward(x, weight, rstd)
        ctx.x_shape_start = x_shape_start

        y = y.reshape(x_shape_start)
        return y
    @staticmethod
    def backward(ctx, dy):
        x, weight, rstd = ctx.saved_tensors
        eps = ctx.eps
        x_shape_start = ctx.x_shape_start

        # Flatten input and output gradients
        dy = dy.view(-1, dy.shape[-1])
        if dy.stride(-1) != 1:
            dy = dy.contiguous()

        M, N = dy.shape
        dx = torch.empty_like(x)
        dw = torch.empty_like(weight)

        # One partial dw buffer per SM, reduced after the kernel finishes
        sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
        _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)

        max_size = 65536 // x.element_size()
        block_N = min(max_size, triton.next_power_of_2(N))
        rows_per_sm = math.ceil(M / sm_count)
        if N > block_N:
            raise ValueError(f"feature dim N={N} must fit in a single block (block_N={block_N})")

        grid = lambda meta: (sm_count,)
        _rms_norm_bwd_kernel_sm[grid](
            x,
            x.stride(0),
            weight,
            dy,
            dy.stride(0),
            dx,
            dx.stride(0),
            rstd,
            _dw,
            eps,
            M,
            N,
            rows_per_sm,
            block_N,
        )

        dw = _dw.sum(0).to(weight.dtype)
        dx = dx.view(x_shape_start)
        return dx, dw, None
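

# A minimal eager-mode reference for sanity-checking the fused kernels. The
# name `rms_norm_ref` is introduced here for illustration and is not part of
# the original gist.
def rms_norm_ref(x, weight, eps=1e-6):
    x_f32 = x.float()
    rstd = torch.rsqrt(x_f32.pow(2).mean(dim=-1, keepdim=True) + eps)
    return (x_f32 * rstd * weight.float()).to(x.dtype)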


if __name__ == "__main__":

    def fused_rms_norm_fn(
        x,
        weight,
        eps=1e-6,
    ):
        return TritonFusedRMSNorm.apply(
            x,
            weight,
            eps,
        )

    def fn(x, weight):
        x += 1  # in-place mutation of the input, exercised under torch.compile
        new_weight = weight + 1
        return fused_rms_norm_fn(x, new_weight)

    x = torch.randn(200, device="cuda")
    weight = torch.nn.Parameter(torch.ones(200, device="cuda"))

    compiled_fn = torch.compile(fn)
    out = compiled_fn(x, weight)
    loss = out.sum()
    loss.backward()
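
    # Hedged sketch: compare the fused forward against the eager reference
    # above on fresh tensors; the tolerances are illustrative assumptions.
    x_check = torch.randn(200, device="cuda")
    weight_check = torch.nn.Parameter(torch.ones(200, device="cuda"))
    out_fused = fused_rms_norm_fn(x_check, weight_check)
    torch.testing.assert_close(out_fused, rms_norm_ref(x_check, weight_check), rtol=1e-4, atol=1e-4)
    out_fused.sum().backward()  # populates weight_check.grad via the Triton backward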