# Peter Chng pchng

Last active April 28, 2024 18:57
Triton Puzzles: Simple FlashAttention: Using one program_id to block over q, and loop over k, v columns

    # This is the tiling approach in: https://courses.cs.washington.edu/courses/cse599m/23sp/notes/flashattn.pdf
    def flashatt_spec(q: Float32[Tensor, "200"], k: Float32[Tensor, "200"], v: Float32[Tensor, "200"]) -> Float32[Tensor, "200"]:
        x = q[:, None] * k[None, :]
        x_max = x.max(1, keepdim=True)[0]
        x = x - x_max
        x_exp = x.exp()
        soft = x_exp / x_exp.sum(1, keepdim=True)
        return (v[None, :] * soft).sum(1)

    @triton.jit
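The spec above materializes the full score matrix before normalizing. A tiled version instead walks over `k` and `v` in blocks and keeps a running maximum, denominator, and weighted sum. The following is a minimal pure-Python sketch of that idea, not the gist's Triton kernel: the function names `flashatt_reference` and `flashatt_tiled`, the default `block` size, and the use of plain lists are all assumptions made for illustration, and it handles one `q` element at a time rather than a block of them.

```python
import math


def flashatt_reference(q, k, v):
    """Direct softmax-weighted sum, mirroring flashatt_spec on plain lists."""
    out = []
    for qi in q:
        scores = [qi * kj for kj in k]
        m = max(scores)
        exps = [math.exp(s - m) for s in scores]
        denom = sum(exps)
        out.append(sum(e * vj for e, vj in zip(exps, v)) / denom)
    return out


def flashatt_tiled(q, k, v, block=4):
    """Same result, but k and v are visited one block at a time.

    `block` is a hypothetical tile size chosen for this sketch.
    """
    out = []
    for qi in q:
        running_max = -math.inf
        denom = 0.0  # running sum of exp(score - running_max)
        acc = 0.0    # running sum of exp(score - running_max) * v
        for start in range(0, len(k), block):
            k_blk = k[start:start + block]
            v_blk = v[start:start + block]
            scores = [qi * kj for kj in k_blk]
            new_max = max(running_max, max(scores))
            # Rescale the earlier partial sums to the new maximum.
            scale = math.exp(running_max - new_max)
            exps = [math.exp(s - new_max) for s in scores]
            denom = denom * scale + sum(exps)
            acc = acc * scale + sum(e * vj for e, vj in zip(exps, v_blk))
            running_max = new_max
        out.append(acc / denom)
    return out
```

Because every partial sum is rescaled whenever the maximum changes, the tiled result should agree with the direct computation up to floating-point rounding, regardless of the block size.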
Created March 6, 2024 19:44

    // Same as above, but row/col set to x/y instead.
    __global__ void matMulBad(float *left, float *right, float *out, int a, int b, int c) {
      int row = blockIdx.x * blockDim.x + threadIdx.x;
      int col = blockIdx.y * blockDim.y + threadIdx.y;
      if (row < a && col < c) {
        float sum = 0.0;
        for (int i = 0; i < b; i++) {
          // 1. If row/threadIdx.x is changing within the warp, then on each iteration the threads do a strided access:
          // They will access elements separated by a stride of b. This results in non-coalesced accesses (multiple memory reads)
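The stride described in the comment can be checked without a GPU by listing the element offsets that the 32 threads of one warp would touch at a given loop step. The sketch below is an illustration only: the helper names `warp_addresses` and `strides` are hypothetical, it looks only at the first warp of block (0, 0), and it assumes row-major indexing of the form `left[row * b + i]` and `right[i * c + col]`, which the truncated preview does not show.

```python
WARP_SIZE = 32


def warp_addresses(block_dim_x, b, c, i, row_from_x):
    """Element offsets into left and right touched by the first warp at step i.

    Assumes row-major storage: left[row * b + i] and right[i * c + col].
    """
    left_offsets, right_offsets = [], []
    for lane in range(WARP_SIZE):
        # Threads in a block are linearized with threadIdx.x varying fastest.
        tx = lane % block_dim_x
        ty = lane // block_dim_x
        # row_from_x=True mimics matMulBad (row taken from x, col from y).
        row, col = (tx, ty) if row_from_x else (ty, tx)
        left_offsets.append(row * b + i)
        right_offsets.append(i * c + col)
    return left_offsets, right_offsets


def strides(offsets):
    """Distinct differences between offsets of neighbouring lanes."""
    return sorted({y - x for x, y in zip(offsets, offsets[1:])})


if __name__ == "__main__":
    b, c = 4000, 3000
    bad_left, bad_right = warp_addresses(32, b, c, i=0, row_from_x=True)
    good_left, good_right = warp_addresses(32, b, c, i=0, row_from_x=False)
    print("row from x:", strides(bad_left), strides(bad_right))
    print("row from y:", strides(good_left), strides(good_right))
```

Under these assumptions, taking `row` from `threadIdx.x` makes neighbouring lanes read `left` at offsets `b` apart, while taking `col` from `threadIdx.x` makes them read consecutive elements of `right`, which is the pattern the hardware can coalesce.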
Created March 6, 2024 19:41
CUDA: matmul good

    // Computes out = left @ right, where `@` is matrix multiplication
    // Dimensions:
    // left: a x b
    // right: b x c
    // out: a x c
    // monolithic kernel: One thread per output element in matrix C.
    // (Each thread computes the dot product between a row in `left` and a col in `right`)
    __global__ void matMul(float *left, float *right, float *out, int a, int b, int c) {
      // Use y to index to rows, x to index to cols (just to match typical visualization)
      // row indexes into left, col indexes into right.
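The "one thread per output element" layout can be emulated on the CPU by treating each `(row, col)` pair as a stand-in for one thread. This is a rough sketch under assumptions, not the gist's kernel body: the function name `matmul_one_thread_per_element` is hypothetical, and the flat row-major indexing is assumed rather than taken from the truncated preview.

```python
def matmul_one_thread_per_element(left, right, a, b, c):
    """Multiply flat row-major matrices: left is a x b, right is b x c.

    Returns the a x c result as a flat row-major list.
    """
    out = [0.0] * (a * c)
    # Each (row, col) pair stands in for one GPU thread.
    for row in range(a):
        for col in range(c):
            total = 0.0
            # Dot product of one row of left with one column of right.
            for i in range(b):
                total += left[row * b + i] * right[i * c + col]
            out[row * c + col] = total
    return out


if __name__ == "__main__":
    left = [1, 2, 3, 4, 5, 6]        # 2 x 3
    right = [7, 8, 9, 10, 11, 12]    # 3 x 2
    print(matmul_one_thread_per_element(left, right, 2, 3, 2))
```

On a GPU the two outer loops disappear: the grid launches one thread per `(row, col)`, and the bounds check in the kernel discards threads that fall outside the output matrix.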
Created March 6, 2024 19:38
CUDA matrix multiplication: Warp-thread linearization test

    #include
    #define A 3000
    #define B 4000
    #define C 3000

    // Computes out = left @ right, where `@` is matrix multiplication
    // Dimensions:
    // left: a x b
    // right: b x c
Last active September 8, 2018 21:14

    # Fish prompt to show username, host, CWD, Python VirtualEnv, VCS (git, svn, hg) info.
    function fish_prompt --description 'Write out the prompt'
        set -l last_status $status

        if not set -q __fish_git_prompt_show_informative_status
            set -g __fish_git_prompt_show_informative_status 1
        end
        if not set -q __fish_git_prompt_hide_untrackedfiles
            set -g __fish_git_prompt_hide_untrackedfiles 1
        end
Created February 6, 2016 00:29

### Keybase proof

I hereby claim:

• I am pchng on github.
• I am pchng (https://keybase.io/pchng) on keybase.
• I have a public key whose fingerprint is F722 6230 DB84 09B3 FFB2 737B B1EB 75BC 98B5 23A8

To claim this, I am signing this object:

Created March 9, 2015 02:56