This is an optimized implementation of an RMSNorm inference kernel written in Triton, a Python-based language and compiler for GPU programming. It is a modified version of the excellent RMSNorm kernel from the Unsloth project.
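For reference, RMSNorm scales each row of the input by the reciprocal of its root mean square, then multiplies by a learned per-feature weight. The pure-Python sketch below assumes the standard Llama-style formulation (weight applied multiplicatively, `eps` added inside the square root). The helper name `rms_norm_row` is hypothetical and is not part of the kernel's code.

```python
import math


def rms_norm_row(x, weight, eps=1e-6):
    """Reference RMSNorm for one row (hypothetical helper, not the kernel's API).

    Assumes the Llama-style formulation:
        y_i = x_i / sqrt(mean(x^2) + eps) * weight_i
    """
    if len(x) != len(weight):
        raise ValueError("x and weight must have the same length")
    # Mean of squares over the hidden dimension.
    mean_sq = sum(v * v for v in x) / len(x)
    # Reciprocal root mean square; eps guards against division by zero.
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    # Normalize, then apply the per-feature weight.
    return [v * inv_rms * w for v, w in zip(x, weight)]


# Example: one row of 4 hidden features with unit weights.
row = [1.0, 2.0, 3.0, 4.0]
out = rms_norm_row(row, [1.0] * 4)
print([round(v, 4) for v in out])
```

With unit weights, the output row has a root mean square of (almost exactly) 1; `eps` only matters when a row is close to all zeros.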
This kernel makes two improvements over the original Unsloth version:
- `int64` for pointer offset: We use `int64` instead of the default `int32` to compute the pointer offset value. This prevents overflow with large sequence lengths, where the offset of a row (row index × row stride) can exceed the maximum `int32` value (2,147,483,647, about 2.1 billion).
- In-place computation: Our kernel writes the result back to the input buffer, eliminating the need for an additional output allocation. This halves the memory needed for the layer's input and output compared with traditional implementations that use a separate output buffer.
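To make both changes concrete, the pure-Python sketch below mimics what each Triton program instance does: it computes the start offset of its row as `row_idx * row_stride`, and it normalizes that row in place inside a flat, row-major buffer. The `to_int32` helper emulates two's-complement `int32` wraparound, since Python integers never overflow. The hidden size of 4096 used in the example is an illustrative assumption, and all function names here are hypothetical rather than taken from the kernel.

```python
import math

INT32_MAX = 2**31 - 1


def to_int32(value):
    """Emulate two's-complement int32 wraparound (Python ints never overflow)."""
    return (value + 2**31) % 2**32 - 2**31


def row_offset(row_idx, row_stride, use_int64=True):
    """Start offset of a row in a flat buffer (hypothetical helper)."""
    offset = row_idx * row_stride
    # With int32 arithmetic the product silently wraps once it passes INT32_MAX.
    return offset if use_int64 else to_int32(offset)


def rms_norm_inplace(buf, n_rows, n_cols, weight, eps=1e-6):
    """Normalize each row of a flat row-major buffer in place (no output buffer)."""
    for row_idx in range(n_rows):  # in Triton, one program instance per row
        start = row_offset(row_idx, n_cols)
        # Read the whole row first (analogous to loading it into registers),
        # so overwriting the input below cannot corrupt values still needed.
        row = buf[start:start + n_cols]
        inv_rms = 1.0 / math.sqrt(sum(v * v for v in row) / n_cols + eps)
        for j in range(n_cols):
            # Write the result into the input's storage instead of a new buffer.
            buf[start + j] = row[j] * inv_rms * weight[j]
    return buf


hidden = 4096
first_bad_row = 2**31 // hidden  # 524288: first row whose offset exceeds INT32_MAX
print(row_offset(first_bad_row, hidden, use_int64=True))   # 2147483648
print(row_offset(first_bad_row, hidden, use_int64=False))  # -2147483648 (wrapped)

buf = [1.0, 2.0, 3.0, 4.0, 2.0, 2.0, 2.0, 2.0]
rms_norm_inplace(buf, n_rows=2, n_cols=4, weight=[1.0] * 4)
```

With a hidden size of 4096, any row index of 524,288 or more produces a negative (wrapped) `int32` offset, which would point outside the tensor; widening the row index to `int64` before the multiplication avoids this. The in-place version returns the same list object it was given, so no second buffer of the input's size is ever allocated.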
import torch
import triton