@Chillee
Last active April 11, 2024 17:17
Compute Flop Utilization in PyTorch
import torch
from torch.utils.flop_counter import FlopCounterMode
from triton.testing import do_bench

def get_flops_achieved(f):
    flop_counter = FlopCounterMode(display=False)
    with flop_counter:
        f()
    total_flops = flop_counter.get_total_flops()
    ms_per_iter = do_bench(f)
    iters_per_second = 1e3 / ms_per_iter
    print(f"{iters_per_second * total_flops / 1e12} TF/s")

from torchvision.models import resnet18
model = resnet18().cuda().half()
inp = torch.randn(128, 3, 224, 224, device='cuda', dtype=torch.half)
get_flops_achieved(lambda: model(inp).sum().backward())

compiled_model = torch.compile(model)
get_flops_achieved(lambda: compiled_model(inp).sum().backward())
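To turn the achieved TF/s figure into a flop-utilization (MFU) percentage, divide by the hardware's peak throughput. A minimal arithmetic sketch, assuming a hypothetical measured 150 TF/s and the commonly quoted 312 TFLOPS dense fp16/bf16 peak of an A100 (both numbers are assumptions, not from the gist):

```python
# Hypothetical numbers for illustration; neither comes from the gist.
achieved_tflops = 150.0  # e.g. the value printed by get_flops_achieved
peak_tflops = 312.0      # assumed A100 dense fp16/bf16 peak; use your GPU's
mfu = achieved_tflops / peak_tflops
print(f"{mfu:.1%} flop utilization")  # -> 48.1% flop utilization
```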
@152334H commented Mar 30, 2024

This produces the following error on torch 2.2.1:

  File "/home/a/mfu_compute.py", line 20, in <lambda>
    get_flops_achieved(lambda: compiled_model(inp).sum().backward())
  File "/home/a/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/a/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/a/venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 489, in _fn
    return fn(*args, **kwargs)
  File "/home/a/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/a/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/a/venv/lib/python3.10/site-packages/torchvision/models/resnet.py", line 284, in forward
    def forward(self, x: Tensor) -> Tensor:
  File "/home/a/venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 489, in _fn
    return fn(*args, **kwargs)
  File "/home/a/venv/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 17, in inner
    return fn(*args, **kwargs)
  File "/home/a/venv/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 901, in forward
    return compiled_fn(full_args)
  File "/home/a/venv/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py", line 81, in g
    return f(*args)
  File "/home/a/venv/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 83, in runtime_wrapper
    all_outs = call_func_at_runtime_with_args(
  File "/home/a/venv/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py", line 105, in call_func_at_runtime_with_args
    out = normalize_as_list(f(args))
  File "/home/a/venv/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py", line 81, in g
    return f(*args)
  File "/home/a/venv/lib/python3.10/site-packages/torch/autograd/function.py", line 553, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/home/a/venv/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 408, in forward
    fw_outs = call_func_at_runtime_with_args(
  File "/home/a/venv/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py", line 105, in call_func_at_runtime_with_args
    out = normalize_as_list(f(args))
  File "/home/a/venv/lib/python3.10/site-packages/torch/_inductor/codecache.py", line 864, in __call__
    return self.get_current_callable()(inputs)
  File "/home/a/venv/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 611, in run
    return model(new_inputs)
  File "/home/a/venv/lib/python3.10/site-packages/torch/_inductor/codecache.py", line 892, in _run_from_cache
    return compiled_graph.compiled_artifact(inputs)
  File "/tmp/torchinductor_a/lt/cltqowfhe2k3kwflralec2ug4rndsitzkngyxaato6ihdivt3r4f.py", line 2867, in call
    extern_kernels.addmm(primals_62, buf246, reinterpret_tensor(primals_61, (512, 1000), (1, 512), 0), alpha=1, beta=1, out=buf247)
  File "/home/a/venv/lib/python3.10/site-packages/torch/utils/flop_counter.py", line 455, in __torch_dispatch__
    flop_count = flop_count_func(*args, **kwargs, out=out)  # type: ignore[operator]
TypeError: torch.utils.flop_counter.addmm_flop() got multiple values for keyword argument 'out'

I believe this occurs in FlopCounterMode.__torch_dispatch__: when the dispatched function call itself carries an out=... entry in its kwargs, it conflicts with the separate out=out argument passed to flop_count_func():

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs if kwargs else {}
        out = func(*args, **kwargs) # <-- arbitrary func call could contain 'out' in kwargs
        func_packet = func._overloadpacket
        if func_packet in self.flop_registry:
            flop_count_func = self.flop_registry[func_packet]
            flop_count = flop_count_func(*args, **kwargs, out=out)  # will conflict here

I don't believe the out= parameter is important in any of the @register_flop_formula decorated functions, so maybe it could simply be discarded from the kwargs?
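The collision is plain-Python behavior and can be reproduced without torch. A minimal sketch (the names below are hypothetical stand-ins, not the real flop_counter internals): passing out= both inside **kwargs and as an explicit keyword raises exactly this TypeError, and popping the caller's out from kwargs first avoids it:

```python
def flop_count_func(*args, out=None, **kwargs):
    # Stand-in for an @register_flop_formula function; returns a dummy count.
    return 0

kwargs = {"out": "buffer_from_dispatch"}  # dispatched call that used out=...

try:
    flop_count_func(1, 2, **kwargs, out="dispatch_result")
except TypeError as e:
    print(e)  # -> ...got multiple values for keyword argument 'out'

# The suggested fix: drop the caller's 'out' before re-passing it.
kwargs.pop("out", None)
flop_count_func(1, 2, **kwargs, out="dispatch_result")  # no conflict now
```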

@stas00 commented Apr 10, 2024

I have just tried this for the first time and got the same error on the torch.compile segment.

pt=2.2.1+cu121, cuda=12.1

@Chillee

@Chillee (Author) commented Apr 10, 2024

Actually, there are two issues here in some sense.

  1. The flop counter actually isn't really intended to work when running a torch.compile'd model :P This is mostly "fine" because the flops of a compiled model are generally identical to the flops of an uncompiled model. But the interposition points we use (i.e. the dispatcher) are not guaranteed to trigger when running under the compiler (for example, if you're using cudagraphs).
  2. The actual error that you're running into, which is fixed by pytorch/pytorch#123768

EDIT: To clarify, the intended usage of FlopCounterMode with a compiled model is to run it on an uncompiled version of the model.

@stas00 commented Apr 10, 2024

Super! Thank you, Horace.

Do you want me to open an issue to document FlopCounterMode? Its docs are missing, so we users don't know what is kosher to use it for and what not.

@Chillee (Author) commented Apr 10, 2024

> as its docs are missing so we users don't know what is kosher to use it for and what not

I was gonna do it for the PyTorch 2.3 release but I didn't end up getting around to it 😭

Yeah please do open an issue :)
