This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from torch.utils.flop_counter import FlopCounterMode | |
from triton.testing import do_bench | |
def get_flops_achieved(f): | |
flop_counter = FlopCounterMode(display=False) | |
with flop_counter: | |
f() | |
total_flops = flop_counter.get_total_flops() | |
ms_per_iter = do_bench(f) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import jax | |
import jax.numpy as jnp | |
def tree_stack(trees): | |
"""Takes a list of trees and stacks every corresponding leaf. | |
For example, given two trees ((a, b), c) and ((a', b'), c'), returns | |
((stack(a, a'), stack(b, b')), stack(c, c')). | |
Useful for turning a list of objects into something you can feed to a | |
vmapped function. |