#gpu #kernel #rust
- GPU kernels in Rust
- Comptime
- Automatic vectorization
- Instruction and shape specialization
- Loop unrolling
- Autotuning
- WGSL - WebGPU Shading Language
- GLSL - OpenGL
- HLSL - High-level shading language
- MSL - Metal Shading Language
- CUDA
- ROCm
- SYCL
- Rust -> WGSL
- Rust -> CUDA
CubeCL provides runtimes (`cubecl_wgpu` and `cubecl_cuda`) that are built on top of the Wgpu and Cuda backends.
From my understanding, the current implementation includes the following constructs: `ComputeClient`, `ComputeServer`, and a `Channel`, which serves as the abstraction for sending requests from the client to the server.
Instantiating a `ComputeClient` involves two steps:
- Setting up the necessary data structures for each backend (e.g., `wgpu_setup` for Wgpu).
- Creating a client using the data structures from the setup, along with instantiating a `MemoryManagement` type to manage GPU memory allocation and deallocation strategies.
The client essentially wraps a `Channel` and a `FeatureSet`, which is a list of the features supported by each runtime.
Once we have a `ComputeClient`, we can perform various tasks, such as creating or accessing resources (e.g., GPU buffers) and executing kernels. Note that invoking methods on the client will eventually route them to the `ComputeServer`, which holds the necessary Wgpu structures to actually create and access these resources.
```rust
use cubecl::prelude::*;

#[cube(launch_unchecked)]
fn gelu_array<F: Float>(input: &Array<F>, output: &mut Array<F>) {
    if ABSOLUTE_POS < input.len() {
        output[ABSOLUTE_POS] = gelu_scalar::<F>(input[ABSOLUTE_POS]);
    }
}

#[cube]
fn gelu_scalar<F: Float>(x: F) -> F {
    x * (F::erf(x / F::sqrt(2.0.into())) + 1.0) / 2.0
}
```
CubeCL's unique selling point (USP) is its ability to write GPU kernels in Rust, as demonstrated above. However, there are a few things to keep in mind:
- All types used in a CubeCL function must implement the `CubeType` trait. In the example above, both `F` and `Array<F>` are CubeCL types: both implement the `CubeType` trait, and `F` also implements the `Float` trait.
- CubeCL kernels are annotated with procedural macros that expand into Rust functions. These generated functions, which are semantically similar to the original ones, produce the intermediate representation (IR) when invoked.
Key point: Instead of directly generating the IR, the macro first creates a new Rust function.
The Flow:
- In the above example, the CubeCL function annotated with the `#[cube(launch_unchecked)]` macro expands into a module containing a `GeluArray` struct that implements the `Kernel` trait.
```rust
pub struct GeluArray<F: Float, __R: cubecl::prelude::Runtime> {
    settings: cubecl::prelude::KernelSettings,
    __ty: ::core::marker::PhantomData<(__R, F)>,
}
```
- The `GeluArray` struct holds the `KernelSettings` struct. `KernelSettings` allows us to configure various parameters, including the vectorization factor for kernel inputs and outputs.
- Once we configure our `KernelSettings`, we instantiate a `KernelLauncher` and register the associated kernel inputs and outputs for the kernel launch.
- Kernel launching involves several levels of indirection:
    - The `KernelLauncher` invokes the `ComputeClient`'s `execute` method to initiate kernel execution.
    - This method uses a `Channel` to route the call to the `ComputeServer` (in our case, the `WgpuServer`), which executes the kernel with the provided bindings.
- Kernel execution involves preparing the pipeline state.
    - At this stage, the kernel is compiled into source code (i.e., WGSL).
    - Remember, the kernel is simply our `GeluArray` struct, which implements the `Kernel` trait. The `Kernel` trait has two methods (`id` comes with a default implementation):
```rust
pub trait Kernel: Send + Sync + 'static + Sized {
    /// Convert to a kernel definition.
    fn define(&self) -> KernelDefinition;

    /// Identifier for the kernel, used for caching kernel compilation.
    fn id(&self) -> KernelId {
        KernelId::new::<Self>()
    }
}
```
Vectorization factor: for example, `Elem::Float(FloatKind)` with a vectorization factor of 4 represents a 4-element vector of floating-point numbers, which could be processed in a SIMD manner.
Binding struct: a memory binding that connects the tensor handle to the actual memory (storage) on the compute server.
Kernel preparation involves two main steps:
- Kernel Expansion
- Kernel Definition
In the example above:
- Kernel definition begins with instantiating the `KernelBuilder` struct and populating it with the kernel's inputs, outputs, context, and the number of inputs and outputs.
- Two ordered maps are required to convert and store the inputs and outputs as `Variable`s; the order of insertion is crucial. Expanding a kernel input means registering the input and returning the element to be used for kernel expansion. Here, "element" refers to either an `ExpandElement` or an `ExpandElementTyped`, which are simply wrapper types for `Variable`s.
- Now that we have a fully initialized `KernelBuilder` and expanded kernel inputs/outputs, we proceed to actual kernel expansion.
In this phase, the body of the kernel function is expanded. In the `gelu` example, several important data structures are involved in this process:
- `Operation`: CubeCL operations that can be legally used in a GPU compute shader.
- `Variable`: holds data or CubeCL values that can be referenced during GPU compute shader operations.
- `Scope`: a container that holds CubeCL operations and variables.
- `CubeContext`: a wrapper type for `Scope`, containing root and non-root scopes and a `VariablePool`.
- `ExpandElement`: a wrapper type for CubeCL `Variable`s.
- `ExpandElementTyped`: the typed version of `ExpandElement`.
CubeCL operations behave like conventional operations, taking input operands and returning a result. This behavior is modeled in CubeCL IR.
```rust
#[cube(launch_unchecked)]
fn gelu_array<F: Float>(input: &Array<F>, output: &mut Array<F>) {
    if ABSOLUTE_POS < input.len() {
        output[ABSOLUTE_POS] = gelu_scalar::<F>(input[ABSOLUTE_POS]);
    }
}
```
- In our `gelu` example, the `if` condition `ABSOLUTE_POS < input.len()` expands to:
```rust
/// Expanded Cube function
pub fn __expand<F: Float>(
    context: &mut cubecl::frontend::CubeContext,
    input: <Array<F> as cubecl::frontend::CubeType>::ExpandType,
    output: <Array<F> as cubecl::frontend::CubeType>::ExpandType,
) -> () {
    let _cond = {
        let _lhs = ABSOLUTE_POS::expand(context);
        let _rhs = input.clone().__expand_len_method(context);
        cubecl::frontend::lt::expand(context, _lhs, _rhs)
    };
    ...
}
```
- `ABSOLUTE_POS` (or `_lhs`) is a `Variable`.
- `input.len()` (or `_rhs`) is also a `Variable`.
- The less-than operator (`<`) expands into the `lt::expand` operation, with `_lhs` and `_rhs` as inputs, along with the `context`.
- All operations (and their operands) are added to the provided context (`Scope`).
- The order in which they are pushed onto a `CubeContext` (i.e., scope) is crucial.

Note: `_lhs` and `_rhs` are actually `ExpandElementTyped<UInt>`s.
Once the kernel function is expanded, the next step is creating a kernel definition. The main data structures involved are:
- `KernelIntegrator`: enables the creation of a `KernelDefinition` based on a `KernelExpansion` and `KernelSettings`.
- `KernelExpansion`: contains the necessary information to generate a `KernelDefinition`.
- `KernelDefinition`: represents the finalized kernel after expansion and integration, functioning as CubeCL's intermediate representation.
The first step is to instantiate a `KernelIntegrator` and invoke the integrator's `integrate` method with the `KernelSettings`. This method combines the inputs and outputs (from the kernel expansion) into input/output bindings and returns a `KernelDefinition`.
As mentioned earlier, a `KernelDefinition` is the intermediate representation (IR) in CubeCL.
- The final step is to map this IR to the target compute shader source code; in our case, this is WGSL.
- Essentially, we map all variables and operations in CubeCL to the target shader source using the corresponding shader compiler, specifically the `WgslCompiler` in our case.

In other words, the `KernelDefinition` (IR) is mapped to the target compute shader source code, in this case WGSL. The `WgslCompiler` translates (or maps) each IR variable, operation, and input/output binding into its corresponding shader source equivalent.
- Example CubeCL Macro Expansion: https://gist.github.com/nihalpasham/133a935304e22054b0fe92efde43caec
- Example CubeCL IR: https://gist.github.com/nihalpasham/6e4c0edf5b1a0b199c05c186a5a75b2d
- Example CubeCL Generated WGSL Shader: https://gist.github.com/nihalpasham/0ed25f2dbcb08278f79d6ceabf38a60b
Created a video playlist for future reference: https://youtube.com/playlist?list=PLIUa1VcxJuwlI5sg8M8MH6FgzBzUWuAFI&si=ObxnUaUgUZxSebdq