Skip to content

Instantly share code, notes, and snippets.

@Tanmaypatil123
Created November 25, 2025 09:19
Show Gist options
  • Select an option

  • Save Tanmaypatil123/e0ebdeb17f0b8086a7656acc8b3ddc04 to your computer and use it in GitHub Desktop.

Select an option

Save Tanmaypatil123/e0ebdeb17f0b8086a7656acc8b3ddc04 to your computer and use it in GitHub Desktop.
Softmax kernel written with HipKittens
#include <stdexcept>
#include <string>

#include "kittens.cuh"
#include "pyutils/pyutils.cuh"
// Block layout: every thread block runs NUM_WORKERS warps, so NUM_THREADS is
// the block size in threads (warp width as defined by kittens).
#define NUM_WORKERS (4)
#define NUM_THREADS (NUM_WORKERS*kittens::WARP_THREADS)

// Compile-time problem dimensions used to size the launch grid and registers.
constexpr int B = 16;    // batches: extent of the grid's y dimension
constexpr int N = 4096;  // rows per batch: split across the grid's x dimension
constexpr int D = 2048;  // elements per row: length of each warp's register vector

// Group type spanning all NUM_WORKERS warps of a block.
using G = kittens::group<NUM_WORKERS>;

using namespace kittens;
template<int _N> struct softmax_globals {
// global descriptors
using x_gl = gl<bf16, -1, -1, -1, -1>;
using o_gl = gl<bf16, -1, -1, -1, -1>;
x_gl x; // Input
o_gl o; // Output
const int n_per_tile = 4;
const int n_tile_size = N / n_per_tile;
dim3 grid() { return dim3(n_tile_size, B, 1); }
dim3 block() { return dim3(NUM_THREADS); }
size_t dynamic_shared_memory() { return 0; }
};
/**
 * Fused row-wise softmax over vectors of D bf16 elements.
 *
 * Launch layout (see softmax_globals::grid / block):
 *   blockIdx.y - batch index
 *   blockIdx.x - row-tile index; a tile spans g.n_per_tile consecutive rows
 *   warp id    - row within the tile (one row per warp)
 * No dynamic shared memory is used by this kernel.
 *
 * Each warp loads its row from g.x at coordinate {0, batch, row, 0}, subtracts
 * the row maximum, exponentiates, divides by the sum of the exponentials and
 * stores the result to the same coordinate of g.o.
 *
 * NOTE(review): the max/sum reductions and all arithmetic run in bf16; if the
 * accuracy of the normalisation matters, consider accumulating in float -
 * confirm against a float reference.
 */
template<int N_val>
__global__ void fused_softmax_tk(const softmax_globals<N_val> g) {
    const int warpid = kittens::warpid();
    // A tile only contains g.n_per_tile rows. A warp beyond that has no row of
    // its own and would otherwise redo (or overrun) the next tile's rows. The
    // warp id is the same for every lane, so the whole warp leaves together.
    if (warpid >= g.n_per_tile) return;

    const int batch = blockIdx.y;
    const int seq_start = blockIdx.x * g.n_per_tile;
    const int seq_idx = seq_start + warpid;
    // Guard against a grid that overshoots the N_val rows it was sized for.
    if (seq_idx >= N_val) return;

    rv<bf16, D> x_reg;
    load(x_reg, g.x, {0, batch, seq_idx, 0});
    // AMD ISA: wait until the outstanding vector-memory operation count is
    // zero, i.e. the loaded row has arrived before it is consumed below.
    asm volatile("s_waitcnt vmcnt(0)");

    // Subtract the row maximum before exponentiating so exp() cannot overflow.
    bf16 max_val;
    max(max_val, x_reg);
    sub(x_reg, x_reg, max_val);
    exp(x_reg, x_reg);

    // Normalise by the sum of the exponentials.
    bf16 sum_exp;
    sum(sum_exp, x_reg);
    div(x_reg, x_reg, sum_exp);

    store(g.o, x_reg, {0, batch, seq_idx, 0});
}
template<int N_val>
void dispatch_micro(softmax_globals<N_val> g) {
unsigned long mem_size = g.dynamic_shared_memory();
hipFuncSetAttribute(
(void*)fused_softmax_tk<N_val>,
hipFuncAttributeMaxDynamicSharedMemorySize,
mem_size
);
fused_softmax_tk<N_val><<<g.grid(), g.block(), mem_size>>>(g);
hipDeviceSynchronize();
}
// Python extension module `tk_kernel`: exposes the softmax launcher as
// `dispatch_micro`, taking the input and output tensors in that order.
PYBIND11_MODULE(tk_kernel, m) {
    using globals_t = softmax_globals<N>;

    m.doc() = "1D Softmax kernel module";

    // The member pointers select which fields of the globals struct are
    // filled from the Python arguments: first `x`, then `o`.
    py::bind_function<dispatch_micro<N>>(
        m, "dispatch_micro", &globals_t::x, &globals_t::o);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment