This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from torch.utils.flop_counter import FlopCounterMode | |
from triton.testing import do_bench | |
torch.set_default_device('cuda') | |
def get_flops_achieved(f): | |
flop_counter = FlopCounterMode(display=False) | |
with flop_counter: | |
f() | |
total_flops = flop_counter.get_total_flops() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from triton.testing import do_bench | |
# Allocate every tensor below on the GPU by default.
torch.set_default_device('cuda')

# Benchmark a bfloat16 matmul of an (M, K) matrix by a (K, N) matrix.
# Each shape sets exactly one of M, K, N to 2047 and leaves the other two
# at 2048, so the three timings can be compared dimension by dimension.
for M, K, N in [
    (2047, 2048, 2048),
    (2048, 2047, 2048),
    (2048, 2048, 2047),
]:
    A = torch.randn(M, K, dtype=torch.bfloat16)
    B = torch.randn(K, N, dtype=torch.bfloat16)
    print(f"M={M}, K={K}, N={N}")
    # The lambda is invoked inside do_bench during this same iteration, so it
    # always sees the A and B created just above.
    # NOTE(review): do_bench's unit/statistic is defined by Triton — confirm
    # against the installed version before comparing numbers across machines.
    print(do_bench(lambda: torch.mm(A, B)))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn as nn | |
import torch.nn.utils.parametrize as parametrize | |
from torch.utils._pytree import tree_map | |
class LoraTensor:
    """Plain container bundling a weight object with two companion objects.

    NOTE(review): the names suggest ``A`` and ``B`` are the low-rank LoRA
    factors applied on top of ``weights`` — confirm against the code that
    consumes this class; nothing here enforces shapes or types.
    """

    def __init__(self, weights, A, B):
        """Store the three arguments unchanged as same-named attributes."""
        self.weights, self.A, self.B = weights, A, B
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from torch.utils.flop_counter import FlopCounterMode | |
from triton.testing import do_bench | |
def get_flops_achieved(f): | |
flop_counter = FlopCounterMode(display=False) | |
with flop_counter: | |
f() | |
total_flops = flop_counter.get_total_flops() | |
ms_per_iter = do_bench(f) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch._inductor.config | |
import time | |
torch._inductor.config.triton.cudagraphs = False | |
torch.set_float32_matmul_precision('high') | |
def bench(f, name=None, iters=100, warmup=5, display=True, profile=False): | |
for _ in range(warmup): | |
f() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/** | |
* Author: | |
* Description: Link-cut tree. Supports BST-like augmentations (can be used in place of HLD). | |
* The current implementation supports updating the value at a node and querying the max on a path. | |
* For details about the structure, refer to https://en.wikipedia.org/wiki/Link/cut_tree | |
* Tested on: http://acm.timus.ru/problem.aspx?num=1553 | |
* Status: Passes existing fuzz tests (with function names modified). | |
*/ | |
struct Node { | |
bool flip = 0; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Iteration map used by Pollard below: returns (mod_mul(x, x, n) + 1) % n.
// Presumably mod_mul(x, x, n) is x*x reduced mod n, making this the usual
// x -> x^2 + 1 (mod n) step — confirm against mod_mul's definition.
ull f(ull x, ull n) {
    const auto squared = mod_mul(x, x, n);
    return (squared + 1) % n;
}
ull Pollard(ull n) { | |
if (isPrime(n)) return n; | |
if (!(n & 1)) return 2; | |
for(int i = 1; i < 50; i++) { | |
ull x = i, y = f(x, n), p = __gcd(n + y - x, n); | |
while (p == 1) | |
x = f(x, n), y = f(f(y, n), n), p = __gcd(n + y - x, n); | |
if (p == n) continue; | |
return p; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
typedef complex<double> cpx; | |
typedef vector<cpx> Poly; | |
typedef vector<cpx> Eval; | |
Eval FFT(Poly P) { | |
int n = P.size(); | |
if (n == 1) | |
return P; | |
Poly P_even(n / 2), P_odd(n / 2); | |
for (int j = 0; j < n / 2; j++) { | |
P_even[j] = P[j * 2]; // Put all the even terms (2*0, 2*1, 2*2,...) in one polynomial |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Chinese Remainder Theorem for two moduli with gcd(m, n) = 1.
// Combines the congruences x = a (mod m) and x = b (mod n) into a single
// value; the sum of the two contributions is reduced by m*n at most once.
//
// Assumes euclid(m, n, u, v) fills u and v with Bezout coefficients such that
// m*u + n*v = 1 — confirm against euclid's definition.
// NOTE(review): the intermediate products are plain ll multiplications and
// can overflow for large moduli — confirm the intended input range.
ll chinese(ll a, ll m, ll b, ll n) {
    ll coef_m, coef_n;
    euclid(m, n, coef_m, coef_n);

    // Adding the modulus before reducing keeps each coefficient's
    // remainder from going negative.
    ll part_a = a * (coef_n + m) % m * n;
    ll part_b = b * (coef_m + n) % n * m;

    ll combined = part_a + part_b;
    ll full_mod = m * n;
    return combined >= full_mod ? combined - full_mod : combined;
}
ll chinese_common(ll a, ll m, ll b, ll n) { // gcd(m,n) != 1 | |
ll d = __gcd(m, n); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
vector<ll> BerlekampMassey(vector<ll> s) { | |
int n = s.size(), L = 0, m = 0; | |
vector<ll> C(n), B(n), T; | |
C[0] = B[0] = 1; | |
ll b = 1; | |
for (int i = 0; i < n; i++) { | |
m++; | |
ll d = s[i] % MOD; | |
for (int j = 1; j < L + 1; j++) |
Newer | Older