Skip to content

Instantly share code, notes, and snippets.

@zhuangh
Last active November 19, 2023 06:46
Show Gist options
  • Save zhuangh/b1aa958a7dd4671bb76a68695fe970ba to your computer and use it in GitHub Desktop.
Save zhuangh/b1aa958a7dd4671bb76a68695fe970ba to your computer and use it in GitHub Desktop.
cudagraph_decorator.py
import functools

import torch
# acknowledgement: https://gist.github.com/bwasti/7e4cb9bd1aaddeb09bd360b570a486b1
def cudagraph(f):
    """Wrap ``f`` so repeated calls replay a captured CUDA graph.

    One graph is captured per distinct input signature (shape, dtype and
    device of every positional argument). On each call the arguments are
    copied into that graph's static input buffers, the graph is replayed, and
    clones of the static outputs are returned so a later replay cannot
    overwrite results the caller still holds.

    Assumes ``f`` takes only positional CUDA tensors and returns either a
    single tensor or a list/tuple of tensors -- TODO confirm for new callers.

    Args:
        f: Function to capture. It must be capturable by ``torch.cuda.graph``.

    Returns:
        A wrapper called like ``f`` with positional tensors. For a list/tuple
        output it returns a list of cloned tensors; for a single tensor output
        it returns one cloned tensor.
    """
    graphs = {}

    def _capture(args):
        # Static buffers the graph reads from; every call copies into these.
        static_inputs = [a.clone() for a in args]

        # Warm up on a side stream before capture (per the PyTorch CUDA-graphs
        # docs) so one-time lazy initialization is not recorded in the graph.
        side = torch.cuda.Stream()
        side.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(side):
            f(*static_inputs)
        torch.cuda.current_stream().wait_stream(side)

        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph):
            static_outputs = f(*static_inputs)
        # Iterating a bare tensor would walk its first dimension, so a single
        # tensor output is handled separately.
        single = isinstance(static_outputs, torch.Tensor)

        def replay(*call_args):
            for dst, src in zip(static_inputs, call_args):
                dst.copy_(src)
            graph.replay()
            if single:
                return static_outputs.clone()
            return [o.clone() for o in static_outputs]

        # Keep the graph and its static buffers referenced by the cache entry.
        return replay, graph, static_inputs, static_outputs

    @functools.wraps(f)
    def wrapper(*args):
        # Key on the full signature tuple (not its hash) so distinct
        # signatures can never collide. Include dtype/device so copy_ never
        # silently casts inputs into a graph captured for another dtype.
        key = tuple((tuple(a.shape), a.dtype, a.device) for a in args)
        if key not in graphs:
            graphs[key] = _capture(args)
        replay, *_ = graphs[key]
        # Capture only records kernels, so even the first call must replay
        # to produce real outputs.
        return replay(*args)

    return wrapper
# Benchmark operands: two 128x128 random tensors drawn from the host RNG
# (A first, then B) and moved to the default CUDA device.
A, B = [torch.randn(128, 128).cuda() for _ in range(2)]
def foo(a, b):
    """Eager baseline: start from ``a`` and apply ``acc = a * b + acc`` ten times.

    The result is returned in a one-element list so its output has the same
    form as ``bar``'s.
    """
    acc = a
    for _ in range(10):
        acc = a * b + acc
    return [acc]
@cudagraph
def bar(a, b):
    """Same computation as ``foo``, but captured and replayed via ``cudagraph``.

    Starts from ``a`` and applies ``acc = a * b + acc`` ten times, returning
    the result in a one-element list.
    """
    acc = a
    for _ in range(10):
        acc = a * b + acc
    return [acc]
import timeit

# Warm up both paths outside the timed region. The first bar() call warms up
# the function and captures the CUDA graph; that one-time cost should not be
# billed to steady-state replay.
foo(A, B)
bar(A, B)
torch.cuda.synchronize()

# CUDA kernels launch asynchronously, so synchronize inside each timed
# statement. Otherwise timeit measures only host-side launch overhead.
ta = timeit.timeit(
    "foo(A, B); torch.cuda.synchronize()", globals=globals(), number=10
)
tb = timeit.timeit(
    "bar(A, B); torch.cuda.synchronize()", globals=globals(), number=10
)
print(ta, tb)

# Earlier results on a GTX 1060 6GB, in print order (ta = foo, tb = bar).
# They were measured BEFORE the warm-up and synchronize calls above were
# added, so tb includes the one-time graph capture and neither number waits
# for the GPU to finish:
#   foo (eager):     0.03496104599980754
#   bar (cudagraph): 0.09703805100070895
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment