Horace He (Chillee)

Chillee / flex_attention_tutorial.py
Last active July 10, 2024 03:39
flex_attention_tutorial.py
import torch
from torch.nn.attention._flex_attention import _create_block_mask, _create_mask
from functools import partial
from torch.nn.attention._flex_attention import _flex_attention
from triton.testing import do_bench
import torch.nn.functional as F
from functools import lru_cache
torch.set_default_device('cuda')
# Example usage
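
The preview stops at the usage example. As a hedged sketch of what a score_mod-based call might look like (assuming the private _flex_attention mirrors the later public flex_attention(query, key, value, score_mod) signature, where score_mod(score, b, h, q_idx, kv_idx) returns a modified attention score), a causal variant could be written as:

def causal_mask(score, b, h, q_idx, kv_idx):
    # keep scores on or below the diagonal, mask out everything else
    return torch.where(q_idx >= kv_idx, score, float("-inf"))

q, k, v = [torch.randn(8, 16, 1024, 64, dtype=torch.float16) for _ in range(3)]
flex = torch.compile(_flex_attention)  # the fused kernel is generated via torch.compile
out = flex(q, k, v, causal_mask)
print(do_bench(lambda: flex(q, k, v, causal_mask)))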
Chillee / assoc_scan.py
Last active May 31, 2024 21:52
Higher Order Kernel - associative scan
import torch
import torch.nn as nn
from torch._higher_order_ops.associative_scan import associative_scan
from triton.testing import do_bench
torch.set_default_device('cuda')
def combine_fn(i, j):
    # composing two affine updates (a1, b1) then (a2, b2) gives
    # h -> a2 * (a1 * h + b1) + b2, i.e. (a1 * a2, b1 * a2 + b2)
    ia, ib = i
    ja, jb = j
    return ia * ja, ib * ja + jb
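
A hedged usage sketch follows (it assumes associative_scan(combine_fn, xs, dim) accepts a tuple of tensors and may need to run under torch.compile; this API was still in flux): scanning the pair (a, b) with the combine function above evaluates the linear recurrence h_t = a_t * h_{t-1} + b_t in parallel rather than step by step.

# shapes: (batch, seq_len, hidden); the scan runs along the sequence dimension (dim=1)
a = torch.randn(4, 4096, 64)
b = torch.randn(4, 4096, 64)
scan = torch.compile(lambda a, b: associative_scan(combine_fn, (a, b), 1))
_, h = scan(a, b)  # the second element of the output pair is the recurrence state h_t
print(do_bench(lambda: scan(a, b)))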
Chillee / mm_weird.py
Last active June 21, 2024 22:41
Strangely, Matrix Multiplications Run Faster When Given "Predictable" Data!
import torch
torch.set_default_device('cuda')
from triton.testing import do_bench
from collections import defaultdict
from functools import partial
import random
random.seed(0)
def get_flops(A, B):
    ms = do_bench(lambda: torch.mm(A, B))
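
A hedged sketch of how the experiment presumably continues (the completion of the helper, renamed get_flops_sketch here, and the choice of inputs are my assumptions): convert the measured time into achieved TFLOPS, then feed the same matmul "unpredictable" random data versus "predictable" low-entropy data such as all zeros.

def get_flops_sketch(A, B):
    ms = do_bench(lambda: torch.mm(A, B))
    flops = 2 * A.shape[0] * A.shape[1] * B.shape[1]  # 2*M*K*N FLOPs per matmul
    return flops / (ms * 1e-3) / 1e12  # achieved TFLOPS

M = K = N = 8192
A_rand = torch.randn(M, K, dtype=torch.bfloat16)
B_rand = torch.randn(K, N, dtype=torch.bfloat16)
print(f"random data:   {get_flops_sketch(A_rand, B_rand):.1f} TFLOPS")
print(f"all-zero data: {get_flops_sketch(torch.zeros_like(A_rand), torch.zeros_like(B_rand)):.1f} TFLOPS")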
Chillee / attention_dim_bench.py
Created April 12, 2024 05:13
You Could Have Invented Flash-Attention!
import torch
from torch.utils.flop_counter import FlopCounterMode
from triton.testing import do_bench
torch.set_default_device('cuda')
def get_flops_achieved(f):
    flop_counter = FlopCounterMode(display=False)
    with flop_counter:
        f()
    total_flops = flop_counter.get_total_flops()
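
A hedged sketch of the kind of comparison this benchmark builds toward (the eager_attention baseline and the shapes are my assumptions, not the gist's code): eager attention materializes the full (seq, seq) score matrix and is largely bandwidth-bound, while the fused F.scaled_dot_product_attention kernel never writes that matrix out.

import torch.nn.functional as F

def eager_attention(q, k, v):
    scores = (q @ k.transpose(-2, -1)) / (q.shape[-1] ** 0.5)  # (B, H, S, S) is materialized here
    return scores.softmax(dim=-1) @ v

q, k, v = [torch.randn(8, 16, 2048, 64, dtype=torch.float16) for _ in range(3)]
print("eager:", do_bench(lambda: eager_attention(q, k, v)), "ms")
print("fused:", do_bench(lambda: F.scaled_dot_product_attention(q, k, v)), "ms")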
Chillee / Q1.py
Last active April 8, 2024 04:07
What Shapes Do Matrix Multiplications Like?
import torch
from triton.testing import do_bench
torch.set_default_device('cuda')
for M, K, N in [(2047, 2048, 2048), (2048, 2047, 2048), (2048, 2048, 2047)]:
    A = torch.randn(M, K, dtype=torch.bfloat16)
    B = torch.randn(K, N, dtype=torch.bfloat16)
    print(f"M={M}, K={K}, N={N}")
    print(do_bench(lambda: torch.mm(A, B)))
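
As a small follow-up sketch (my addition, not part of the gist), including the fully aligned (2048, 2048, 2048) case and converting milliseconds to achieved TFLOPS makes the penalty for the off-by-one shapes easier to see; dimensions that are not multiples of the tile/alignment size typically fall off the fastest kernels.

for M, K, N in [(2048, 2048, 2048), (2047, 2048, 2048), (2048, 2047, 2048), (2048, 2048, 2047)]:
    A = torch.randn(M, K, dtype=torch.bfloat16)
    B = torch.randn(K, N, dtype=torch.bfloat16)
    ms = do_bench(lambda: torch.mm(A, B))
    tflops = 2 * M * K * N / (ms * 1e-3) / 1e12
    print(f"M={M}, K={K}, N={N}: {ms:.3f} ms, {tflops:.1f} TFLOPS")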
Chillee / lora_example.py
Last active May 14, 2023 09:45
lora_example.py
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize
from torch.utils._pytree import tree_map
class LoraTensor(object):
    # wraps a frozen base weight together with its low-rank LoRA factors A and B
    def __init__(self, weights, A, B):
        self.weights = weights
        self.A = A
        self.B = B
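
A hedged sketch of one way the imported pieces could fit together (this parametrization module is my illustration, not the gist's LoraTensor / tree_map implementation): register a parametrization so the layer's effective weight becomes W + B @ A while only A and B receive gradients.

class LoRAParametrization(nn.Module):
    def __init__(self, in_features, out_features, rank=8):
        super().__init__()
        self.A = nn.Parameter(torch.randn(rank, in_features) * 0.01)
        self.B = nn.Parameter(torch.zeros(out_features, rank))  # zero-init so training starts as a no-op

    def forward(self, W):
        return W + self.B @ self.A  # low-rank update on top of the base weight

layer = nn.Linear(1024, 1024)
layer.weight.requires_grad_(False)  # freeze the base weight
parametrize.register_parametrization(
    layer, "weight", LoRAParametrization(layer.in_features, layer.out_features)
)
out = layer(torch.randn(4, 1024))  # uses W + B @ A under the hood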
Chillee / mfu_compute.py
Last active April 11, 2024 17:17
Compute Flop Utilization in PyTorch
import torch
from torch.utils.flop_counter import FlopCounterMode
from triton.testing import do_bench
def get_flops_achieved(f):
    flop_counter = FlopCounterMode(display=False)
    with flop_counter:
        f()
    total_flops = flop_counter.get_total_flops()
    ms_per_iter = do_bench(f)
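
A hedged sketch of how the function likely finishes and how it is meant to be used (the print format, the example workload, and the 312 TFLOPS bf16 A100 peak are my assumptions): counted FLOPs divided by measured time gives achieved FLOPS, and dividing that by the accelerator's peak gives model flop utilization (MFU).

def get_flops_achieved_sketch(f):
    flop_counter = FlopCounterMode(display=False)
    with flop_counter:
        f()
    total_flops = flop_counter.get_total_flops()
    ms_per_iter = do_bench(f)
    achieved_tflops = total_flops / ms_per_iter / 1e9  # (FLOPs / ms) / 1e9 == TFLOPS per second
    print(f"{achieved_tflops:.1f} TFLOPS/s achieved")
    return achieved_tflops

model = torch.nn.Linear(4096, 4096, device="cuda", dtype=torch.bfloat16)
x = torch.randn(16384, 4096, device="cuda", dtype=torch.bfloat16)
tflops = get_flops_achieved_sketch(lambda: model(x))
print(f"MFU against a 312 TFLOPS bf16 peak (A100): {tflops / 312:.1%}")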
Chillee / 1-pw_op_fusion.py
Last active July 7, 2024 04:12
PT 2.0 Benchmarks
import torch
import torch._inductor.config
import time
torch._inductor.config.triton.cudagraphs = False
torch.set_float32_matmul_precision('high')
def bench(f, name=None, iters=100, warmup=5, display=True, profile=False):
    for _ in range(warmup):
        f()
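
A hedged sketch of how such a harness usually continues, plus the pointwise-fusion experiment the filename suggests (the timing body, renamed bench_sketch, and the example op chain are my assumptions): time an eager chain of pointwise ops against its torch.compile'd version, which inductor can fuse into a single kernel.

def bench_sketch(f, name=None, iters=100, warmup=5):
    for _ in range(warmup):
        f()
    torch.cuda.synchronize()
    begin = time.time()
    for _ in range(iters):
        f()
    torch.cuda.synchronize()
    ms_per_iter = (time.time() - begin) * 1e3 / iters
    if name is not None:
        print(f"{name}: {ms_per_iter:.3f} ms/iter")
    return ms_per_iter

def pointwise_ops(x):
    return (x.cos().sin() * 2 + 1).relu()  # several pointwise ops -> one fused kernel when compiled

x = torch.randn(2**24, device="cuda")
bench_sketch(lambda: pointwise_ops(x), "eager")
compiled = torch.compile(pointwise_ops)
bench_sketch(lambda: compiled(x), "torch.compile")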
Chillee / LCT.cpp
Last active October 26, 2019 16:54
Miscellaneous
/**
* Author:
* Description: Link-cut tree. Supports BST-like augmentations (can be used in place of HLD).
* Current implementation supports update value at a node, and query max on a path.
* For details about the structure, refer to https://en.wikipedia.org/wiki/Link/cut_tree
* Tested on: http://acm.timus.ru/problem.aspx?num=1553
* Status: Passes existing fuzz tests (with function names modified).
*/
struct Node {
    bool flip = 0;
Chillee / factor.h
Last active April 30, 2019 09:25
Pollard Rho (Factoring algorithm)
ull f(ull x, ull n) { return (mod_mul(x, x, n) + 1) % n; }
ull Pollard(ull n) {
    if (isPrime(n)) return n;
    if (!(n & 1)) return 2;
    for (int i = 1; i < 50; i++) {
        ull x = i, y = f(x, n), p = __gcd(n + y - x, n);
        while (p == 1)
            x = f(x, n), y = f(f(y, n), n), p = __gcd(n + y - x, n);
        if (p == n) continue;
        return p;
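
For readers who want to poke at the idea without a C++ harness, here is a hedged Python sketch of the same algorithm (my addition, not part of the gist): Floyd cycle detection on x -> x*x + 1 (mod n), taking gcd(y - x, n) each step until a nontrivial factor appears.

from math import gcd

def pollard_rho_sketch(n):
    if n % 2 == 0:
        return 2
    for start in range(1, 50):
        f = lambda x: (x * x + 1) % n
        x, y = start, f(start)
        p = gcd(n + y - x, n)
        while p == 1:  # advance the tortoise once and the hare twice per step
            x = f(x)
            y = f(f(y))
            p = gcd(n + y - x, n)
        if p != n:
            return p  # nontrivial factor found
    return n  # give up; n is (probably) prime

print(pollard_rho_sketch(10403))  # 10403 = 101 * 103, so this should print 101 or 103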