@Chillee
Last active April 11, 2024 17:17
Compute Flop Utilization in PyTorch
import torch
from torch.utils.flop_counter import FlopCounterMode
from torchvision.models import resnet18
from triton.testing import do_bench


def get_flops_achieved(f):
    # Count the FLOPs of one call to f (forward and backward, if f runs both).
    flop_counter = FlopCounterMode(display=False)
    with flop_counter:
        f()
    total_flops = flop_counter.get_total_flops()
    # do_bench reports the runtime of f in milliseconds.
    ms_per_iter = do_bench(f)
    iters_per_second = 1e3 / ms_per_iter
    print(f"{iters_per_second * total_flops / 1e12} TF/s")


model = resnet18().cuda().half()
inp = torch.randn(128, 3, 224, 224, device='cuda', dtype=torch.half)
# Eager baseline.
get_flops_achieved(lambda: model(inp).sum().backward())

# Same workload under torch.compile.
compiled_model = torch.compile(model)
get_flops_achieved(lambda: compiled_model(inp).sum().backward())
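The script above prints achieved TF/s. To turn that into a utilization figure you also need the GPU's peak throughput for the dtype in use. The helper below is a minimal sketch of that last step, not part of the original gist: `flops_utilization` and its `peak_tflops` parameter are hypothetical names, and the peak value has to come from your own GPU's datasheet (for example, NVIDIA lists 312 dense FP16 tensor-core TFLOPS for the A100).

```python
from typing import Tuple


def flops_utilization(total_flops: float, ms_per_iter: float,
                      peak_tflops: float) -> Tuple[float, float]:
    """Return (achieved TF/s, fraction of peak) for one benchmarked call.

    total_flops: FLOPs per iteration, e.g. from FlopCounterMode.get_total_flops()
    ms_per_iter: runtime per iteration in milliseconds, e.g. from do_bench
    peak_tflops: assumed peak throughput of the GPU for this dtype (hypothetical
                 parameter; look the number up for your own hardware)
    """
    if ms_per_iter <= 0 or peak_tflops <= 0:
        raise ValueError("ms_per_iter and peak_tflops must be positive")
    # Same arithmetic as the gist: iterations per second times FLOPs per iteration.
    iters_per_second = 1e3 / ms_per_iter
    achieved_tflops = iters_per_second * total_flops / 1e12
    return achieved_tflops, achieved_tflops / peak_tflops


if __name__ == "__main__":
    # Placeholder inputs for illustration only, not measured results.
    achieved, util = flops_utilization(total_flops=2e12, ms_per_iter=10.0,
                                       peak_tflops=400.0)
    print(f"{achieved:.1f} TF/s, {util:.1%} of assumed peak")
```

In practice you would have `get_flops_achieved` return `total_flops` and `ms_per_iter` rather than only printing, then pass both to this helper along with the peak for your card.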
Chillee (author) commented Apr 10, 2024

as its docs are missing so we users don't know what is kosher to use it for and what not

I was gonna do it for PyTorch 2.3 release but I didn't end up getting around to it 😭

Yeah please do open an issue :)
