A quick implementation of flash-muon, but without materializing the lower-triangular parts of matrices A and B in the Newton-Schulz iteration (using Triton)
#!/usr/bin/env python
# coding: utf-8
import triton
import triton.language as tl
import torch
from torch import Tensor
try:
    from flash_muon import fast_newtonschulz as fast_newtonschulz_v1
except ImportError:
    print("Failed to import fast_newtonschulz from flash_muon. Please ensure the module is installed:")
    print("\tgit clone https://github.com/nil0x9/flash-muon.git && pip install -e flash-muon/")
    import sys
    sys.exit(1)
assert triton.__version__ >= '3.2.0', "This script requires triton version >= 3.2.0 to run."
assert torch.cuda.is_available(), "Need a CUDA device to run!"
current_device = torch.cuda.current_device()
device_name = torch.cuda.get_device_name(current_device)
print(f"Current CUDA device: {device_name}")
print("The script takes a while to run (autotuning for triton kernels). Set TRITON_PRINT_AUTOTUNING=1 to make autotuning verbose.")
def get_mmt_kernel_autotune_config():
    return [triton.Config({'BLOCK_SIZE_M': blk_m, 'BLOCK_SIZE_K': blk_k, 'GROUP_SIZE_M': grp_sz}, num_stages=n_stages, num_warps=n_warps)
            for blk_m in [32, 64, 128]
            for blk_k in [32, 64]
            for grp_sz in [8]
            for n_stages in [3, 4, 5]
            for n_warps in [2, 4, 8]
            ]
def get_sym_axpbxx_kernel_autotune_config():
    return [triton.Config({'GROUP_SIZE_M': grp_sz}, num_stages=n_stages, num_warps=n_warps)
            for grp_sz in [4, 8]
            for n_stages in [1, 2, 3, 4, 5]
            for n_warps in [1, 2, 4, 8]
            ]
def get_sym_aypbxy_kernel_autotune_config():
    return [triton.Config({'BLOCK_SIZE_N': blk_n, 'GROUP_SIZE_M': grp_sz}, num_stages=n_stages, num_warps=n_warps)
            for blk_n in [32, 64, 128]
            for grp_sz in [4, 8]
            for n_stages in [1, 2, 3, 4, 5]
            for n_warps in [1, 2, 4, 8]
            ]
@triton.autotune(
    configs=get_mmt_kernel_autotune_config(),
    key=['M', 'K'],
)
@triton.jit
def mmt_kernel(
    x, y,
    M, K,
    stride_xm, stride_xk,
    stride_ym, stride_yn,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr
):
    """
    Core kernel jit function of matmul_transpose that computes y = x @ x.T
    The code is a simple adaptation from the triton `matmul` tutorial:
    https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html
    """
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m
    # y = x @ x.T is symmetric, so only the upper-triangular blocks are computed;
    # blocks below the diagonal exit early and their memory is left untouched.
    if pid_m > pid_n:
        return
    offs_xm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_xn = (pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    # we use a & b ptrs to denote different rows of x.
    a_ptrs = x + (offs_xm[:, None] * stride_xm + offs_k[None, :] * stride_xk)
    b_ptrs = x + (offs_xn[:, None] * stride_xm + offs_k[None, :] * stride_xk)
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_M), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
        accumulator = tl.dot(a, tl.permute(b, (1, 0)), accumulator)
        a_ptrs += BLOCK_SIZE_K * stride_xk
        b_ptrs += BLOCK_SIZE_K * stride_xk
    # use dtype.element_ty to accommodate different input datatypes as in cpp templates
    # https://github.com/triton-lang/triton/issues/2252
    c = accumulator.to(x.dtype.element_ty, fp_downcast_rounding='rtne')
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    c_ptrs = y + stride_ym * offs_cm[:, None] + stride_yn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M)
    tl.store(c_ptrs, c, mask=c_mask)
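# A minimal host-side wrapper around mmt_kernel, shown for illustration only; the name
# `matmul_transpose_upper` is an assumption of this sketch and is not used elsewhere
# (fast_ns_iter below launches the kernel directly). Note that only the upper-triangular
# blocks of the output are written; the rest of the buffer stays uninitialized.
def matmul_transpose_upper(x: Tensor) -> Tensor:
    x = x.contiguous()
    M, K = x.shape
    out = torch.empty((M, M), device=x.device, dtype=x.dtype)
    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(M, META['BLOCK_SIZE_M']), )
    mmt_kernel[grid](
        x, out,
        M, K,
        x.stride(0), x.stride(1),
        out.stride(0), out.stride(1),
    )
    return out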
@triton.autotune(
    configs=get_sym_axpbxx_kernel_autotune_config(),
    key=['M'],
)
@triton.jit
def sym_axpbxx_kernel(
    x, y,
    alpha, beta,
    M,
    stride_xm, stride_xk,
    stride_ym, stride_yn,
    BLOCK_SIZE_M: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr
):
    """
    calculate y = alpha * x + beta * x @ x, where x is a symmetric matrix and
    alpha & beta are scalars. Only the upper-triangular blocks of x are read,
    and only the upper-triangular blocks of y are written.
    """
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m
    if pid_m > pid_n:
        return
    offs_xm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_xn = (pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_k = tl.arange(0, BLOCK_SIZE_M)
    # we use a & b ptrs to denote different rows of x.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_M), dtype=tl.float32)
    ktile = 0
    # k < pid_m: x[pid_m, k] lies below the diagonal, so load the stored upper tile
    # x[k, pid_m] and transpose it; x[k, pid_n] is upper-triangular and loaded directly.
    for k in range(0, pid_m):
        a_ptrs = x + ((offs_k[:, None] + ktile) * stride_xm + offs_xm[None, :] * stride_xk)
        b_ptrs = x + ((offs_k[:, None] + ktile) * stride_xm + offs_xn[None, :] * stride_xk)
        a = tl.load(a_ptrs, mask=offs_k[:, None] < M - ktile, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < M - ktile, other=0.0)
        accumulator = tl.dot(tl.permute(a, (1, 0)), b, accumulator)
        ktile += BLOCK_SIZE_M
    # pid_m <= k <= pid_n: both x[pid_m, k] and x[k, pid_n] are upper-triangular tiles.
    for k in range(pid_m, pid_n + 1):
        a_ptrs = x + (offs_xm[:, None] * stride_xm + (offs_k[None, :] + ktile) * stride_xk)
        b_ptrs = x + ((offs_k[:, None] + ktile) * stride_xm + offs_xn[None, :] * stride_xk)
        a = tl.load(a_ptrs, mask=offs_k[None, :] < M - ktile, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < M - ktile, other=0.0)
        accumulator = tl.dot(a, b, accumulator)
        ktile += BLOCK_SIZE_M
    # k > pid_n: x[k, pid_n] lies below the diagonal, so load x[pid_n, k] and transpose it.
    for k in range(pid_n + 1, tl.cdiv(M, BLOCK_SIZE_M)):
        a_ptrs = x + (offs_xm[:, None] * stride_xm + (offs_k[None, :] + ktile) * stride_xk)
        b_ptrs = x + (offs_xn[:, None] * stride_xm + (offs_k[None, :] + ktile) * stride_xk)
        a = tl.load(a_ptrs, mask=offs_k[None, :] < M - ktile, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[None, :] < M - ktile, other=0.0)
        accumulator = tl.dot(a, tl.permute(b, (1, 0)), accumulator)
        ktile += BLOCK_SIZE_M
    # use dtype.element_ty to accommodate different input datatypes as in cpp templates
    # https://github.com/triton-lang/triton/issues/2252
    c = accumulator.to(x.dtype.element_ty, fp_downcast_rounding='rtne')
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    a_tile_ptrs = x + stride_xm * offs_cm[:, None] + stride_xk * offs_cn[None, :]
    c_ptrs = y + stride_ym * offs_cm[:, None] + stride_yn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M)
    a_tile = tl.load(a_tile_ptrs, mask=c_mask, other=0.0)
    c = alpha * a_tile + beta * c
    tl.store(c_ptrs, c, mask=c_mask)
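# Illustrative-only check for sym_axpbxx_kernel; the helper name `_check_sym_axpbxx`
# and its defaults are assumptions of this sketch, not part of flash-muon. It is not
# called anywhere in this script, but running it by hand should pass on the upper triangle.
def _check_sym_axpbxx(M=256, block_m=64, alpha=-4.7750, beta=2.0315):
    s = torch.randn(M, M, device='cuda', dtype=torch.float16) / 10
    s = (s + s.mT) / 2  # the kernel expects a symmetric input (only the upper triangle is read)
    out = torch.empty_like(s)
    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(M, META['BLOCK_SIZE_M']), )
    sym_axpbxx_kernel[grid](
        s, out, alpha, beta, M,
        s.stride(0), s.stride(1),
        out.stride(0), out.stride(1),
        BLOCK_SIZE_M=block_m,
    )
    ref = alpha * s + beta * s @ s
    upper = torch.triu(torch.ones(M, M, dtype=torch.bool, device=s.device))
    assert torch.allclose(out[upper], ref[upper], rtol=1e-2, atol=1e-2)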
@triton.autotune(
    configs=get_sym_aypbxy_kernel_autotune_config(),
    key=['M', 'N'],
)
@triton.jit
def sym_aypbxy_kernel(
    x, y,
    z,
    alpha, beta,
    M,
    N,
    stride_xm, stride_xk,
    stride_ym, stride_yn,
    stride_zm, stride_zn,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr
):
    """
    calculate z = alpha * y + beta * x @ y, where x is a symmetric matrix (only its
    upper-triangular blocks are read), y is a regular matrix, and alpha & beta are scalars
    """
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m
    offs_xm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_yn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_M)
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    ktile = 0
    # k < pid_m: x[pid_m, k] lies below the diagonal, so load the stored upper tile
    # x[k, pid_m] and transpose it.
    for k in range(0, pid_m):
        a_ptrs = x + ((offs_k[:, None] + ktile) * stride_xm + offs_xm[None, :] * stride_xk)
        b_ptrs = y + ((offs_k[:, None] + ktile) * stride_ym + offs_yn[None, :] * stride_yn)
        a = tl.load(a_ptrs, mask=offs_k[:, None] < M - ktile, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < M - ktile, other=0.0)
        accumulator = tl.dot(tl.permute(a, (1, 0)), b, accumulator)
        ktile += BLOCK_SIZE_M
    # k >= pid_m: x[pid_m, k] is an upper-triangular tile and is loaded directly.
    for k in range(pid_m, tl.cdiv(M, BLOCK_SIZE_M)):
        a_ptrs = x + (offs_xm[:, None] * stride_xm + (offs_k[None, :] + ktile) * stride_xk)
        b_ptrs = y + ((offs_k[:, None] + ktile) * stride_ym + offs_yn[None, :] * stride_yn)
        a = tl.load(a_ptrs, mask=offs_k[None, :] < M - ktile, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < M - ktile, other=0.0)
        accumulator = tl.dot(a, b, accumulator)
        ktile += BLOCK_SIZE_M
    # use dtype.element_ty to accommodate different input datatypes as in cpp templates
    # https://github.com/triton-lang/triton/issues/2252
    c = accumulator.to(x.dtype.element_ty)
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    y_tile_ptrs = y + stride_ym * offs_cm[:, None] + stride_yn * offs_cn[None, :]
    c_ptrs = z + stride_zm * offs_cm[:, None] + stride_zn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    y_tile = tl.load(y_tile_ptrs, mask=c_mask, other=0.0)
    c = alpha * y_tile + beta * c
    tl.store(c_ptrs, c, mask=c_mask)
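# Illustrative-only check for sym_aypbxy_kernel; the helper name `_check_sym_aypbxy`
# and its defaults are assumptions of this sketch, not part of flash-muon. Unlike the
# two kernels above, every block of the output z is written, so the full matrices can
# be compared.
def _check_sym_aypbxy(M=256, N=128, block_m=64, alpha=3.4445, beta=1.0):
    x = torch.randn(M, M, device='cuda', dtype=torch.float16) / 10
    x = (x + x.mT) / 2  # the kernel reads only the upper triangle of a symmetric x
    y = torch.randn(M, N, device='cuda', dtype=torch.float16) / 10
    z = torch.empty_like(y)
    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
    sym_aypbxy_kernel[grid](
        x, y, z, alpha, beta, M, N,
        x.stride(0), x.stride(1),
        y.stride(0), y.stride(1),
        z.stride(0), z.stride(1),
        BLOCK_SIZE_M=block_m,
    )
    ref = alpha * y + beta * x @ y
    assert torch.allclose(z, ref, rtol=1e-2, atol=1e-2)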
def fast_ns_iter(X, a=3.4445, b=-4.7750, c=2.0315):
    """
    One Newton-Schulz iteration X' = a * X + (b * A + c * A @ A) @ X with A = X @ X.T,
    computed with the kernels above so that only the upper-triangular blocks of the
    symmetric intermediates A and B are ever materialized.
    """
    X = X.contiguous()
    M, K = X.shape
    # A = X @ X.T (upper-triangular blocks only)
    A = torch.empty((M, M), device=X.device, dtype=X.dtype)
    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(M, META['BLOCK_SIZE_M']), )
    mmt_kernel[grid](
        X,
        A,
        M,
        K,
        X.stride(0),
        X.stride(1),
        A.stride(0),
        A.stride(1)
    )
    # reuse the tuned block size so the follow-up kernels tile A consistently
    BLOCK_SIZE_M = mmt_kernel.best_config.kwargs['BLOCK_SIZE_M']
    # B = b * A + c * A @ A (upper-triangular blocks only)
    B = torch.empty((M, M), device=X.device, dtype=X.dtype)
    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(M, META['BLOCK_SIZE_M']), )
    sym_axpbxx_kernel[grid](
        A,
        B,
        b,
        c,
        M,
        A.stride(0),
        A.stride(1),
        B.stride(0),
        B.stride(1),
        BLOCK_SIZE_M=BLOCK_SIZE_M
    )
    # X' = a * X + B @ X
    N = K  # TODO: rename
    X_ = torch.empty((M, K), device=X.device, dtype=X.dtype)
    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
    sym_aypbxy_kernel[grid](
        B,
        X,
        X_,
        a,
        1.0,
        M,
        N,
        B.stride(0),
        B.stride(1),
        X.stride(0),
        X.stride(1),
        X_.stride(0),
        X_.stride(1),
        BLOCK_SIZE_M=BLOCK_SIZE_M
    )
    return A, B, X_
def ref_ns_iter(X, a=3.4445, b=-4.7750, c=2.0315):
    A = X @ X.T
    B = b * A + c * A @ A
    X = a * X + B @ X
    return A, B, X
x = torch.randn(512, 512).cuda().half() / 100
A, B, X_ = fast_ns_iter(x)
A_ref, B_ref, X_ref = ref_ns_iter(x)
for (res, ref, name) in [(A, A_ref, 'A'), (B, B_ref, 'B'), (X_, X_ref, 'X')]:
    mask = torch.isclose(res, ref, rtol=1e-2, atol=1e-2)
    if name != 'X':
        # A and B only have valid upper-triangular blocks, so ignore the lower triangle
        size = mask.size(0)
        mask |= torch.tril(torch.ones(size, size, dtype=torch.bool, device=mask.device))
    assert torch.all(mask), f"Results do not match for {name}"
    print(f"Results match for Tensor {name}")
def newtonschulz_base(G: Tensor, steps: int = 5) -> Tensor:
    assert G.ndim >= 2
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G.bfloat16()
    if G.size(-2) > G.size(-1):
        X = X.mT
    # Ensure spectral norm is at most 1
    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
    # Perform the NS iterations
    for _ in range(steps):
        A = X @ X.T
        B = b * A + c * A @ A
        X = a * X + B @ X
    if G.size(-2) > G.size(-1):
        X = X.mT
    return X
def fast_newtonschulz_v2(G: Tensor, steps: int = 5) -> Tensor:
    assert G.ndim == 2, "fast_ns_iter only handles 2D matrices"
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G.bfloat16()
    if G.size(-2) > G.size(-1):
        X = X.mT
    # Ensure spectral norm is at most 1
    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
    # Perform the NS iterations
    for _ in range(steps):
        _, _, X = fast_ns_iter(X, a, b, c)
    if G.size(-2) > G.size(-1):
        X = X.mT
    return X
for N in [1024, 2048, 4096, 8192]:
    x = torch.randn(N, N, device='cuda')
    base = triton.testing.do_bench(lambda: newtonschulz_base(x))
    flash_v1 = triton.testing.do_bench(lambda: fast_newtonschulz_v1(x, steps=5))
    flash_v2 = triton.testing.do_bench(lambda: fast_newtonschulz_v2(x))
    print(f"Dimension: {N:<5} | Torch: {base:>10.3f} ms | Flash V1: {flash_v1:>10.3f} ms | Flash V2: {flash_v2:>10.3f} ms")
"""Example output: | |
Current CUDA device: NVIDIA A100-PCIE-40GB | |
The scripts takes abit long to run (autotuning for triton kernels). Set TRITON_PRINT_AUTOTUNING=1 to make autotuning verbal. | |
Results match for Tensor A | |
Results match for Tensor B | |
Results match for Tensor X | |
Dimension: 1024 | Torch: 0.703 ms | Flash V1: 1.355 ms| Flash V2: 1.090 ms | |
Dimension: 2048 | Torch: 2.272 ms | Flash V1: 1.611 ms| Flash V2: 2.507 ms | |
Dimension: 4096 | Torch: 10.946 ms | Flash V1: 8.413 ms| Flash V2: 8.147 ms | |
Dimension: 8192 | Torch: 79.655 ms | Flash V1: 59.229 ms| Flash V2: 62.253 ms | |
""" | |