-
-
Save Tanmaypatil123/e0ebdeb17f0b8086a7656acc8b3ddc04 to your computer and use it in GitHub Desktop.
softmax kernel hip kittens
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <stdexcept>
#include <string>

#include "kittens.cuh"
#include "pyutils/pyutils.cuh"
| constexpr int B = 16; | |
| constexpr int N = 4096; | |
| constexpr int D = 2048; | |
| #define NUM_WORKERS (4) | |
| #define NUM_THREADS (NUM_WORKERS*kittens::WARP_THREADS) | |
| using G = kittens::group<NUM_WORKERS>; | |
| using namespace kittens; | |
| template<int _N> struct softmax_globals { | |
| // global descriptors | |
| using x_gl = gl<bf16, -1, -1, -1, -1>; | |
| using o_gl = gl<bf16, -1, -1, -1, -1>; | |
| x_gl x; // Input | |
| o_gl o; // Output | |
| const int n_per_tile = 4; | |
| const int n_tile_size = N / n_per_tile; | |
| dim3 grid() { return dim3(n_tile_size, B, 1); } | |
| dim3 block() { return dim3(NUM_THREADS); } | |
| size_t dynamic_shared_memory() { return 0; } | |
| }; | |
| template<int N_val> | |
| __global__ void fused_softmax_tk(const softmax_globals<N_val> g) { | |
| const int warpid = kittens::warpid(); | |
| const int batch = blockIdx.y; | |
| const int seq_start = blockIdx.x * g.n_per_tile; | |
| const int seq_idx = seq_start + warpid; | |
| rv<bf16, D> x_reg; | |
| load(x_reg, g.x, {0, batch, seq_idx, 0}); | |
| asm volatile("s_waitcnt vmcnt(0)"); | |
| bf16 max_val; | |
| max(max_val, x_reg); | |
| sub(x_reg, x_reg, max_val); | |
| exp(x_reg, x_reg); | |
| bf16 sum_exp; | |
| sum(sum_exp, x_reg); | |
| div(x_reg, x_reg, sum_exp); | |
| store(g.o, x_reg, {0, batch, seq_idx, 0}); | |
| } | |
| template<int N_val> | |
| void dispatch_micro(softmax_globals<N_val> g) { | |
| unsigned long mem_size = g.dynamic_shared_memory(); | |
| hipFuncSetAttribute( | |
| (void*)fused_softmax_tk<N_val>, | |
| hipFuncAttributeMaxDynamicSharedMemorySize, | |
| mem_size | |
| ); | |
| fused_softmax_tk<N_val><<<g.grid(), g.block(), mem_size>>>(g); | |
| hipDeviceSynchronize(); | |
| } | |
| PYBIND11_MODULE(tk_kernel, m) { | |
| m.doc() = "1D Softmax kernel module"; | |
| py::bind_function<dispatch_micro<N>>(m, "dispatch_micro", | |
| &softmax_globals<N>::x, | |
| &softmax_globals<N>::o | |
| ); | |
| } | |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment