Skip to content

Instantly share code, notes, and snippets.

@nil0x9
Last active May 20, 2025 08:37
Show Gist options
  • Save nil0x9/16956ab4b66fcfbf9f81a174fef6bf71 to your computer and use it in GitHub Desktop.
Save nil0x9/16956ab4b66fcfbf9f81a174fef6bf71 to your computer and use it in GitHub Desktop.
A quick implementation of flash-muon that avoids computing the lower-triangular parts of the symmetric matrices A and B in the Newton-Schulz iteration (using Triton).
#!/usr/bin/env python
# coding: utf-8
import pandas as pd
import triton
import triton.language as tl
import torch
from torch import Tensor
try:
    from flash_muon import fast_newtonschulz as fast_newtonschulz_v1
except ImportError as e:
    # Diagnostics go to stderr, and the underlying error is shown: an
    # ImportError raised *inside* flash_muon (e.g. a missing transitive
    # dependency) would otherwise be misreported as "module not installed".
    import sys
    print("Failed to import fast_newtonschulz from flash_muon. Please ensure the module is installed:", file=sys.stderr)
    print("\tgit clone https://github.com/nil0x9/flash-muon.git && pip install -e flash-muon/", file=sys.stderr)
    print(f"\t(underlying error: {e!r})", file=sys.stderr)
    sys.exit(1)
import re

# Compare the version numerically on (major, minor). A plain string comparison
# ranks e.g. '3.10.0' below '3.2.0' and would wrongly reject newer Triton
# releases. Explicit raises (instead of assert) also survive `python -O`.
_triton_version_match = re.match(r"(\d+)\.(\d+)", triton.__version__)
if _triton_version_match is None or tuple(map(int, _triton_version_match.groups())) < (3, 2):
    raise RuntimeError("This scripts requires triton version >= 3.2.0 to run.")
if not torch.cuda.is_available():
    raise RuntimeError("Need CUDA device to run!")
current_device = torch.cuda.current_device()
device_name = torch.cuda.get_device_name(current_device)
print(f"Current CUDA device: {device_name}")
print("The scripts takes abit long to run (autotuning for triton kernels). Set TRITON_PRINT_AUTOTUNING=1 to make autotuning verbal.")
def get_mmt_kernel_autotune_config():
    """Return the autotuning search space for ``mmt_kernel``.

    Every combination of tile sizes, group size, pipeline stages and warp
    count is enumerated, with the warp count varying fastest.
    """
    configs = []
    for block_m in (32, 64, 128):
        for block_k in (32, 64):
            for group in (8,):
                for stages in (3, 4, 5):
                    for warps in (2, 4, 8):
                        meta = {'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_K': block_k, 'GROUP_SIZE_M': group}
                        configs.append(triton.Config(meta, num_stages=stages, num_warps=warps))
    return configs
def get_sym_axpbxx_kernel_autotune_config():
    """Return the autotuning search space for ``sym_axpbxx_kernel``.

    BLOCK_SIZE_M is deliberately absent from these configs, so the caller must
    pass it explicitly when launching the kernel.
    """
    group_sizes = (4, 8)
    stage_counts = range(1, 6)
    warp_counts = (1, 2, 4, 8)
    return [
        triton.Config({'GROUP_SIZE_M': group}, num_stages=stages, num_warps=warps)
        for group in group_sizes
        for stages in stage_counts
        for warps in warp_counts
    ]
def get_sym_aypbxy_kernel_autotune_config():
    """Return the autotuning search space for ``sym_aypbxy_kernel``.

    Only BLOCK_SIZE_N is tuned here; BLOCK_SIZE_M must be supplied by the
    caller at launch time.
    """
    configs = []
    for block_n in (32, 64, 128):
        for group in (4, 8):
            for stages in range(1, 6):
                for warps in (1, 2, 4, 8):
                    configs.append(
                        triton.Config(
                            {'BLOCK_SIZE_N': block_n, 'GROUP_SIZE_M': group},
                            num_stages=stages,
                            num_warps=warps,
                        )
                    )
    return configs
@triton.autotune(
    configs=get_mmt_kernel_autotune_config(),
    key=['M', 'K'],  # re-run autotuning whenever the problem shape changes
)
@triton.jit
def mmt_kernel(
    x, y,
    M, K,
    stride_xm, stride_xk,
    stride_ym, stride_yn,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr
):
    """
    Core kernel jit function of matmul_transpose that computes y = x @ x.T
    The code is a simple adaptation from the triton `matmul` tutorial:
    https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html

    x is an (M, K) matrix and y an (M, M) matrix, both addressed through the
    given strides. Each program owns one BLOCK_SIZE_M x BLOCK_SIZE_M tile of y.
    Because y is symmetric, programs whose tile lies strictly below the block
    diagonal (pid_m > pid_n) exit immediately, so those tiles of y are never
    written; diagonal tiles are written in full. Products are accumulated in
    float32 and rounded (round-to-nearest-even) to x's element type on store.

    NOTE(review): the pid -> (pid_m, pid_n) mapping assumes a 1-D launch grid
    of exactly cdiv(M, BLOCK_SIZE_M) ** 2 programs.
    """
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(M, BLOCK_SIZE_M)
    # Grouped ordering (as in the tutorial): consecutive pids sweep down a
    # group of GROUP_SIZE_M block-rows before moving to the next block-column.
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m
    # y is symmetric: skip tiles strictly below the block diagonal.
    if pid_m > pid_n:
        return
    # Row indices past M wrap around (% M) so loads stay in bounds; the rows
    # they produce are discarded by the store mask below.
    offs_xm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_xn = (pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    # we use a & b ptrs to denote different rows of x.
    a_ptrs = x + (offs_xm[:, None] * stride_xm + offs_k[None, :] * stride_xk)
    b_ptrs = x + (offs_xn[:, None] * stride_xm + offs_k[None, :] * stride_xk)
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_M), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Zero-fill the ragged last K chunk so it contributes nothing.
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
        # Tile of x @ x.T: rows(m) . rows(n)^T.
        accumulator = tl.dot(a, tl.permute(b, (1, 0)), accumulator)
        a_ptrs += BLOCK_SIZE_K * stride_xk
        b_ptrs += BLOCK_SIZE_K * stride_xk
    # use dtype.element_ty to accommodate different input datatypes as in cpp templates
    # https://github.com/triton-lang/triton/issues/2252
    c = accumulator.to(x.dtype.element_ty, fp_downcast_rounding='rtne')
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    c_ptrs = y + stride_ym * offs_cm[:, None] + stride_yn * offs_cn[None, :]
    # Only in-range entries of the tile are stored.
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M)
    tl.store(c_ptrs, c, mask=c_mask)
@triton.autotune(
    configs=get_sym_axpbxx_kernel_autotune_config(),
    key=['M'],
)
@triton.jit
def sym_axpbxx_kernel(
    x, y,
    alpha, beta,
    M,
    stride_xm, stride_xk,
    stride_ym, stride_yn,
    BLOCK_SIZE_M: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr
):
    """
    Calculate y = alpha * x + beta * x @ x, where x is a symmetric (M, M)
    matrix and alpha & beta are scalars.

    Only tiles of x on or above the block diagonal are read. Any tile that
    would lie below it is obtained from its mirror via symmetry
    (x[k, m] == x[m, k].T). The reduction tile size equals BLOCK_SIZE_M, so x
    must use the same BLOCK_SIZE_M tiling, with its diagonal tiles fully
    written. Only tiles of y with pid_m <= pid_n are written; tiles strictly
    below the block diagonal are left untouched.

    Every load is masked along the axis that carries the reduction index k,
    so a ragged last tile (M % BLOCK_SIZE_M != 0) neither reads out of
    bounds nor zeroes valid entries.
    """
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(M, BLOCK_SIZE_M)
    # Grouped program ordering (see the Triton matmul tutorial).
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m
    # y is symmetric: tiles strictly below the block diagonal are skipped.
    if pid_m > pid_n:
        return
    # Out-of-range row/column indices wrap (% M) so loads stay in bounds;
    # the corresponding outputs are discarded by the store mask.
    offs_xm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_xn = (pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_k = tl.arange(0, BLOCK_SIZE_M)
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_M), dtype=tl.float32)
    ktile = 0
    # k < pid_m: tile x[m, k] is below the diagonal, so load x[k, m] and
    # transpose it. Both tiles here are indexed [k, *] (k runs along rows).
    for k in range(0, pid_m):
        k_rows_valid = offs_k[:, None] < M - ktile
        a_ptrs = x + ((offs_k[:, None] + ktile) * stride_xm + offs_xm[None, :] * stride_xk)
        b_ptrs = x + ((offs_k[:, None] + ktile) * stride_xm + offs_xn[None, :] * stride_xk)
        a = tl.load(a_ptrs, mask=k_rows_valid, other=0.0)
        b = tl.load(b_ptrs, mask=k_rows_valid, other=0.0)
        accumulator = tl.dot(tl.permute(a, (1, 0)), b, accumulator)
        ktile += BLOCK_SIZE_M
    # pid_m <= k <= pid_n: x[m, k] (k along columns) and x[k, n] (k along
    # rows) are both on/above the block diagonal.
    for k in range(pid_m, pid_n + 1):
        a_ptrs = x + (offs_xm[:, None] * stride_xm + (offs_k[None, :] + ktile) * stride_xk)
        b_ptrs = x + ((offs_k[:, None] + ktile) * stride_xm + offs_xn[None, :] * stride_xk)
        a = tl.load(a_ptrs, mask=offs_k[None, :] < M - ktile, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < M - ktile, other=0.0)
        accumulator = tl.dot(a, b, accumulator)
        ktile += BLOCK_SIZE_M
    # k > pid_n: x[k, n] is below the diagonal, so load x[n, k] and transpose
    # it. Both tiles here are indexed [*, k] (k runs along columns).
    for k in range(pid_n + 1, tl.cdiv(M, BLOCK_SIZE_M)):
        k_cols_valid = offs_k[None, :] < M - ktile
        a_ptrs = x + (offs_xm[:, None] * stride_xm + (offs_k[None, :] + ktile) * stride_xk)
        b_ptrs = x + (offs_xn[:, None] * stride_xm + (offs_k[None, :] + ktile) * stride_xk)
        a = tl.load(a_ptrs, mask=k_cols_valid, other=0.0)
        b = tl.load(b_ptrs, mask=k_cols_valid, other=0.0)
        accumulator = tl.dot(a, tl.permute(b, (1, 0)), accumulator)
        ktile += BLOCK_SIZE_M
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    # The alpha * x term reads x with x's own strides.
    x_tile_ptrs = x + stride_xm * offs_cm[:, None] + stride_xk * offs_cn[None, :]
    c_ptrs = y + stride_ym * offs_cm[:, None] + stride_yn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M)
    x_tile = tl.load(x_tile_ptrs, mask=c_mask, other=0.0).to(tl.float32)
    # Combine in float32 and round once, instead of rounding x @ x first.
    c = alpha * x_tile + beta * accumulator
    # use dtype.element_ty to accommodate different input datatypes as in cpp templates
    # https://github.com/triton-lang/triton/issues/2252
    c = c.to(x.dtype.element_ty, fp_downcast_rounding='rtne')
    tl.store(c_ptrs, c, mask=c_mask)
@triton.autotune(
    configs=get_sym_aypbxy_kernel_autotune_config(),
    key=['M', 'N'],
)
@triton.jit
def sym_aypbxy_kernel(
    x, y,
    z,
    alpha, beta,
    M,
    N,
    stride_xm, stride_xk,
    stride_ym, stride_yn,
    stride_zm, stride_zn,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr
):
    """
    Calculate z = alpha * y + beta * x @ y, where x is a symmetric (M, M)
    matrix, y and z are (M, N) matrices, and alpha & beta are scalars.
    The result is written to z; y is only read.

    Only tiles of x on or above the block diagonal are read; tiles below it
    come from their mirror via symmetry. The reduction tile size equals
    BLOCK_SIZE_M, so x must use the same BLOCK_SIZE_M tiling, with its
    diagonal tiles fully written. Unlike the other kernels, every (pid_m, pid_n)
    output tile is computed (z is not assumed symmetric).
    """
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    # Grouped program ordering (see the Triton matmul tutorial).
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m
    # Wrapped indices keep loads in bounds; results masked at store.
    offs_xm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_yn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_M)
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    ktile = 0
    # k < pid_m: x[m, k] is below the diagonal, so load x[k, m] and transpose.
    for k in range(0, pid_m):
        a_ptrs = x + ((offs_k[:, None] + ktile) * stride_xm + offs_xm[None, :] * stride_xk)
        b_ptrs = y + ((offs_k[:, None] + ktile) * stride_ym + offs_yn[None, :] * stride_yn)
        # NOTE(review): `a` is indexed [k, m] but its mask tests offs_k on the
        # column axis. This is harmless only because every k-tile with
        # k < pid_m lies entirely below M (assuming the grid has exactly
        # num_pid_m * num_pid_n programs); offs_k[:, None] would be the
        # semantically correct axis.
        a = tl.load(a_ptrs, mask=offs_k[None, :] < M - ktile, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < M - ktile, other=0.0)
        accumulator = tl.dot(tl.permute(a, (1, 0)), b, accumulator)
        ktile += BLOCK_SIZE_M
    # k >= pid_m: x[m, k] is on/above the diagonal and read directly; masks
    # zero the ragged tail of the last k-tile (columns of x, rows of y).
    for k in range(pid_m, tl.cdiv(M, BLOCK_SIZE_M)):
        a_ptrs = x + (offs_xm[:, None] * stride_xm + (offs_k[None, :] + ktile) * stride_xk)
        b_ptrs = y + ((offs_k[:, None] + ktile) * stride_ym + offs_yn[None, :] * stride_yn)
        a = tl.load(a_ptrs, mask=offs_k[None, :] < M - ktile, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < M - ktile, other=0.0)
        accumulator = tl.dot(a, b, accumulator)
        ktile += BLOCK_SIZE_M
    # use dtype.element_ty to accommodate different input datatypes as in cpp templates
    # https://github.com/triton-lang/triton/issues/2252
    # (no explicit fp_downcast_rounding here, unlike the other two kernels)
    c = accumulator.to(x.dtype.element_ty)
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    y_tile_ptrs = y + stride_ym * offs_cm[:, None] + stride_yn * offs_cn[None, :]
    c_ptrs = z + stride_zm * offs_cm[:, None] + stride_zn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    # `a_tile` is the matching tile of y for the alpha * y term.
    a_tile = tl.load(y_tile_ptrs, mask=c_mask, other=0.0)
    c = alpha * a_tile + beta * c
    tl.store(c_ptrs, c, mask=c_mask)
def fast_ns_iter(X, a=3.4445, b=-4.7750, c=2.0315):
    """Run one Newton-Schulz iteration using the three Triton kernels.

    Computes
        A     = X @ X.T
        B     = b * A + c * A @ A
        X_new = a * X + B @ X

    A and B are allocated with torch.empty and only their tiles on or above the
    block diagonal are written; tiles strictly below hold uninitialized memory.
    X_new is fully written.

    Returns:
        tuple ``(A, B, X_new)``.
    """
    X = X.contiguous()
    rows, cols = X.shape
    alloc = dict(device=X.device, dtype=X.dtype)

    def square_grid(meta):
        # One program per (block-row, block-col) tile of an (rows, rows) matrix.
        tiles = triton.cdiv(rows, meta['BLOCK_SIZE_M'])
        return (tiles * tiles, )

    def rect_grid(meta):
        # One program per tile of the (rows, cols) output.
        return (triton.cdiv(rows, meta['BLOCK_SIZE_M']) * triton.cdiv(cols, meta['BLOCK_SIZE_N']), )

    # Stage 1: gram = X @ X.T (upper block-triangle only).
    gram = torch.empty((rows, rows), **alloc)
    mmt_kernel[square_grid](
        X, gram,
        rows, cols,
        X.stride(0), X.stride(1),
        gram.stride(0), gram.stride(1),
    )
    # The symmetric kernels must tile exactly like mmt_kernel did, so reuse its
    # autotuned BLOCK_SIZE_M.
    tile_m = mmt_kernel.best_config.kwargs['BLOCK_SIZE_M']

    # Stage 2: poly = b * gram + c * gram @ gram.
    poly = torch.empty((rows, rows), **alloc)
    sym_axpbxx_kernel[square_grid](
        gram, poly,
        b, c,
        rows,
        gram.stride(0), gram.stride(1),
        poly.stride(0), poly.stride(1),
        BLOCK_SIZE_M=tile_m,
    )

    # Stage 3: updated = a * X + 1.0 * poly @ X.
    updated = torch.empty((rows, cols), **alloc)
    sym_aypbxy_kernel[rect_grid](
        poly, X, updated,
        a, 1.0,
        rows, cols,
        poly.stride(0), poly.stride(1),
        X.stride(0), X.stride(1),
        updated.stride(0), updated.stride(1),
        BLOCK_SIZE_M=tile_m,
    )
    return gram, poly, updated
def ref_ns_iter(X, a=3.4445, b=-4.7750, c=2.0315):
    """Reference (plain PyTorch) Newton-Schulz iteration.

    Returns the full matrices ``(X @ X.T, b * A + c * A @ A, a * X + B @ X)``
    so they can be compared against ``fast_ns_iter``.
    """
    gram = X @ X.T
    poly = b * gram + c * gram @ gram
    updated = a * X + poly @ X
    return gram, poly, updated
# Sanity check: compare one Triton Newton-Schulz iteration against the
# plain-PyTorch reference on a small fp16 input.
x = torch.randn(512, 512).cuda().half()/100
A, B, X_ = fast_ns_iter(x)
A_ref, B_ref, X_ref = ref_ns_iter(x)
for (res, ref, name) in [(A, A_ref, 'A'), (B, B_ref, 'B'), (X_, X_ref, 'X')]:
    mask = torch.isclose(res, ref, rtol=1e-2, atol=1e-2)
    if name != 'X':
        # A and B are symmetric and the kernels leave tiles below the block
        # diagonal unwritten, so the strictly-lower triangle is ignored.
        # diagonal=-1 keeps the main diagonal (which IS computed) under test;
        # the default diagonal=0 would silently exclude it as well.
        size = mask.size(0)
        mask |= torch.tril(torch.ones(size, size, dtype=torch.bool, device=mask.device), diagonal=-1)
    assert torch.all(mask), f"Results not match for {name}"
    print(f"Results match for Tensor {name}")
def newtonschulz_base(G: Tensor, steps: int = 5) -> Tensor:
    """Approximately orthogonalize G with the quintic Newton-Schulz iteration.

    Plain-PyTorch baseline. The input is cast to bfloat16, and tall matrices
    are transposed so the Gram matrix is built on the smaller dimension. Each
    matrix is then scaled by its Frobenius norm, which keeps the spectral norm
    at most 1, and ``steps`` iterations are applied.

    Args:
        G: tensor with at least 2 dims; leading dims are treated as a batch.
        steps: number of Newton-Schulz iterations.

    Returns:
        bfloat16 tensor with the same shape as G.
    """
    assert G.ndim >= 2
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G.bfloat16()
    transposed = G.size(-2) > G.size(-1)
    if transposed:
        X = X.mT
    # Ensure spectral norm is at most 1
    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
    # Perform the NS iterations. `.mT` transposes only the last two dims;
    # `.T` would reverse *all* dims and break batched (3-D+) inputs.
    for _ in range(steps):
        A = X @ X.mT
        B = b * A + c * A @ A
        X = a * X + B @ X
    if transposed:
        X = X.mT
    return X
def fast_newtonschulz_v2(G: Tensor, steps: int = 5) -> Tensor:
    """Newton-Schulz orthogonalization of G driven by the Triton ``fast_ns_iter``.

    Follows the same recipe as the PyTorch baseline: cast to bfloat16,
    transpose tall inputs, normalize by the Frobenius norm, then iterate
    ``steps`` times with the quintic coefficients.

    NOTE(review): fast_ns_iter unpacks ``X.shape`` into two values, so in
    practice G must be 2-D even though ``G.ndim >= 2`` is what is asserted.
    """
    assert G.ndim >= 2
    coeffs = (3.4445, -4.7750, 2.0315)
    is_tall = G.size(-2) > G.size(-1)
    X = G.bfloat16()
    if is_tall:
        X = X.mT
    # Ensure spectral norm is at most 1
    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
    for _ in range(steps):
        # Only the updated X is needed; A and B are intermediates.
        X = fast_ns_iter(X, *coeffs)[2]
    return X.mT if is_tall else X
# Benchmark the PyTorch baseline against flash-muon v1 and this Triton v2 on
# square inputs of increasing size (timings printed in ms).
for N in (1024, 2048, 4096, 8192):
    x = torch.randn(N, N, device='cuda')
    base, flash_v1, flash_v2 = (
        triton.testing.do_bench(fn)
        for fn in (
            lambda: newtonschulz_base(x),
            lambda: fast_newtonschulz_v1(x, steps=5),
            lambda: fast_newtonschulz_v2(x),
        )
    )
    print(f"Dimension: {N:<5} | Torch: {base:>10.3f} ms | Flash V1: {flash_v1:>10.3f} ms| Flash V2: {flash_v2:>10.3f} ms")
"""Example output:
Current CUDA device: NVIDIA A100-PCIE-40GB
The scripts takes abit long to run (autotuning for triton kernels). Set TRITON_PRINT_AUTOTUNING=1 to make autotuning verbal.
Results match for Tensor A
Results match for Tensor B
Results match for Tensor X
Dimension: 1024 | Torch: 0.703 ms | Flash V1: 1.355 ms| Flash V2: 1.090 ms
Dimension: 2048 | Torch: 2.272 ms | Flash V1: 1.611 ms| Flash V2: 2.507 ms
Dimension: 4096 | Torch: 10.946 ms | Flash V1: 8.413 ms| Flash V2: 8.147 ms
Dimension: 8192 | Torch: 79.655 ms | Flash V1: 59.229 ms| Flash V2: 62.253 ms
"""
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment