Skip to content

Instantly share code, notes, and snippets.

View Chillee's full-sized avatar

Horace He Chillee

View GitHub Profile
@Chillee
Chillee / attention_dim_bench.py
Created April 12, 2024 05:13
You Could Have Invented Flash-Attention!
import torch
from torch.utils.flop_counter import FlopCounterMode
from triton.testing import do_bench
torch.set_default_device('cuda')
def get_flops_achieved(f):
flop_counter = FlopCounterMode(display=False)
with flop_counter:
f()
total_flops = flop_counter.get_total_flops()
@Chillee
Chillee / Q1.py
Last active April 8, 2024 04:07
What Shapes Do Matrix Multiplications Like?
import torch
from triton.testing import do_bench
torch.set_default_device('cuda')
# Benchmark a bfloat16 matrix multiply for three shapes; in each shape exactly
# one of M, K, N is 2047 while the other two stay at 2048. For every shape the
# dimensions are printed, followed by the value returned by do_bench.
for M, K, N in [(2047, 2048, 2048), (2048, 2047, 2048), (2048, 2048, 2047)]:
    # Operands are created on the default device configured above.
    A = torch.randn(M, K, dtype=torch.bfloat16)
    B = torch.randn(K, N, dtype=torch.bfloat16)
    print(f"M={M}, K={K}, N={N}")
    # The lambda is handed straight to do_bench inside this iteration, so
    # closing over the loop-local A and B is safe.
    print(do_bench(lambda: torch.mm(A, B)))
@Chillee
Chillee / lora_example.py
Last active May 14, 2023 09:45
lora_example.py
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize
from torch.utils._pytree import tree_map
class LoraTensor(object):
def __init__(self, weights, A, B):
self.weights = weights
self.A = A
self.B = B
@Chillee
Chillee / mfu_compute.py
Last active April 11, 2024 17:17
Compute Flop Utilization in PyTorch
import torch
from torch.utils.flop_counter import FlopCounterMode
from triton.testing import do_bench
def get_flops_achieved(f):
flop_counter = FlopCounterMode(display=False)
with flop_counter:
f()
total_flops = flop_counter.get_total_flops()
ms_per_iter = do_bench(f)
@Chillee
Chillee / 1-pw_op_fusion.py
Last active February 26, 2024 20:45
PT 2.0 Benchmarks
import torch
import torch._inductor.config
import time
torch._inductor.config.triton.cudagraphs = False
torch.set_float32_matmul_precision('high')
def bench(f, name=None, iters=100, warmup=5, display=True, profile=False):
for _ in range(warmup):
f()
@Chillee
Chillee / LCT.cpp
Last active October 26, 2019 16:54
Miscellaneous
/**
* Author:
* Description: Link-cut tree. Supports BST-like augmentations (can be used in place of HLD).
* The current implementation supports updating the value at a node and querying the max on a path.
* For details about the structure, refer to https://en.wikipedia.org/wiki/Link/cut_tree
* Tested on: http://acm.timus.ru/problem.aspx?num=1553
* Status: Passes existing fuzz tests (with function names modified).
*/
struct Node {
bool flip = 0;
@Chillee
Chillee / factor.h
Last active April 30, 2019 09:25
Pollard Rho (Factoring algorithm)
// Iteration step used by Pollard(): returns (mod_mul(x, x, n) + 1) % n.
// NOTE(review): mod_mul is defined elsewhere; presumably it computes
// x * x mod n without overflow -- confirm against its definition.
ull f(ull x, ull n) {
	ull squared = mod_mul(x, x, n);
	return (squared + 1) % n;
}
ull Pollard(ull n) {
if (isPrime(n)) return n;
if (!(n & 1)) return 2;
for(int i = 1; i < 50; i++) {
ull x = i, y = f(x, n), p = __gcd(n + y - x, n);
while (p == 1)
x = f(x, n), y = f(f(y, n), n), p = __gcd(n + y - x, n);
if (p == n) continue;
return p;
@Chillee
Chillee / fft.cpp
Created April 23, 2019 22:42
Educational implementations
typedef complex<double> cpx;
typedef vector<cpx> Poly;
typedef vector<cpx> Eval;
Eval FFT(Poly P) {
int n = P.size();
if (n == 1)
return P;
Poly P_even(n / 2), P_odd(n / 2);
for (int j = 0; j < n / 2; j++) {
P_even[j] = P[j * 2]; // Put all the even terms (2*0, 2*1, 2*2,...) in one polynomial
@Chillee
Chillee / crt.cpp
Created April 13, 2019 04:18
Chinese Remainder Theorem
// Chinese Remainder Theorem for two congruences:
//   x = a (mod m), x = b (mod n), with gcd(m, n) = 1.
// Returns the combined value, reduced once by m * n if it reaches that bound.
// NOTE(review): euclid() is defined elsewhere; presumably it is the extended
// Euclidean algorithm filling coef_m, coef_n with m*coef_m + n*coef_n == gcd(m, n)
// -- confirm against its definition.
ll chinese(ll a, ll m, ll b, ll n) {
	ll coef_m, coef_n;
	euclid(m, n, coef_m, coef_n);
	// Adding m (resp. n) before the modulo keeps a negative coefficient
	// from producing a negative remainder.
	ll part_a = a * (coef_n + m) % m * n;
	ll part_b = b * (coef_m + n) % n * m;
	ll ret = part_a + part_b;
	ll bound = m * n;
	return ret >= bound ? ret - bound : ret;
}
ll chinese_common(ll a, ll m, ll b, ll n) { // gcd(m,n) != 1
ll d = __gcd(m, n);
@Chillee
Chillee / BerlekampMassey.cpp
Created April 4, 2019 02:27
Linear Recurrence (Berlekamp Massey, K-th term)
vector<ll> BerlekampMassey(vector<ll> s) {
int n = s.size(), L = 0, m = 0;
vector<ll> C(n), B(n), T;
C[0] = B[0] = 1;
ll b = 1;
for (int i = 0; i < n; i++) {
m++;
ll d = s[i] % MOD;
for (int j = 1; j < L + 1; j++)