Alex Wang KeAWang
KeAWang / tree_stack.py
Last active June 5, 2023 16:05 — forked from willwhitney/tree_stack.py
Utilities for stacking and unstacking JAX pytrees for use with vmap
import jax
import jax.numpy as jnp

def tree_stack(trees):
    """Takes a list of trees and stacks every corresponding leaf.
    For example, given two trees ((a, b), c) and ((a', b'), c'), returns
    ((stack(a, a'), stack(b, b')), stack(c, c')).
    Useful for turning a list of objects into something you can feed to a
    vmapped function.
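
The preview of tree_stack.py stops inside the docstring, so the function body is not shown here. As a minimal sketch of the behaviour the docstring describes (not necessarily how the gist implements it), the stacking can be written with `jax.tree_util.tree_map`. The names `tree_stack_sketch` and `tree_unstack_sketch` and the example trees are hypothetical, and the sketch assumes every input tree has the same structure and leaf shapes.

```python
import jax
import jax.numpy as jnp


def tree_stack_sketch(trees):
    """Stack corresponding leaves of identically structured pytrees (hypothetical helper)."""
    # tree_map walks all trees in parallel, so each call sees one leaf per tree.
    return jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *trees)


def tree_unstack_sketch(tree):
    """Split the leading axis of every leaf back into a list of pytrees (hypothetical helper)."""
    leaves, treedef = jax.tree_util.tree_flatten(tree)
    # zip(*leaves) steps along the leading axis of all leaves in lockstep.
    return [jax.tree_util.tree_unflatten(treedef, row) for row in zip(*leaves)]


# Two example trees with the ((a, b), c) structure from the docstring.
tree_a = ((jnp.array(1.0), jnp.array(2.0)), jnp.array(3.0))
tree_b = ((jnp.array(10.0), jnp.array(20.0)), jnp.array(30.0))

stacked = tree_stack_sketch([tree_a, tree_b])

# The stacked tree can be fed to a vmapped function that expects a single tree.
sums = jax.vmap(lambda t: t[0][0] + t[0][1] + t[1])(stacked)

unstacked = tree_unstack_sketch(stacked)
```

Here every leaf of `stacked` gains a leading axis of length 2, and `tree_unstack_sketch` recovers a list of two trees with the original structure.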
KeAWang / mfu_compute.py
Created April 11, 2024 17:17 — forked from Chillee/mfu_compute.py
Compute Flop Utilization in PyTorch
import torch
from torch.utils.flop_counter import FlopCounterMode
from triton.testing import do_bench

def get_flops_achieved(f):
    flop_counter = FlopCounterMode(display=False)
    with flop_counter:
        f()
    total_flops = flop_counter.get_total_flops()
    ms_per_iter = do_bench(f)
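
The preview of mfu_compute.py ends after the two measurements are taken. A plausible next step, sketched here under assumptions rather than copied from the gist, is to turn `total_flops` (FLOPs per call) and `ms_per_iter` (latency in milliseconds, as `do_bench` reports it) into an achieved throughput and compare it with the hardware peak. The function names, the sample numbers, and the peak value below are all hypothetical; a real peak figure must come from the accelerator's specification.

```python
def tflops_per_second(total_flops: float, ms_per_iter: float) -> float:
    """Convert FLOPs per call and latency in milliseconds to TFLOP/s."""
    iters_per_second = 1e3 / ms_per_iter
    return iters_per_second * total_flops / 1e12


def flop_utilization(achieved_tflops: float, peak_tflops: float) -> float:
    """Fraction of the hardware peak that was actually achieved."""
    return achieved_tflops / peak_tflops


# Illustrative numbers only: 2e12 FLOPs per call, 10 ms per call.
achieved = tflops_per_second(total_flops=2e12, ms_per_iter=10.0)

# Hypothetical peak of 400 TFLOP/s; substitute the value for the real device.
utilization = flop_utilization(achieved, peak_tflops=400.0)

print(f"{achieved:.1f} TFLOP/s, {utilization:.0%} of peak")
```

Keeping the arithmetic separate from the measurement makes it easy to check by hand: 2e12 FLOPs every 10 ms is 100 calls per second, or 200 TFLOP/s.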