Skip to content

Instantly share code, notes, and snippets.

@bwasti
Created August 24, 2023 17:50
Show Gist options
  • Save bwasti/7e4cb9bd1aaddeb09bd360b570a486b1 to your computer and use it in GitHub Desktop.
Save bwasti/7e4cb9bd1aaddeb09bd360b570a486b1 to your computer and use it in GitHub Desktop.
import functools

import torch
def cudagraph(f):
    """Decorate ``f`` so each new input signature is captured into a CUDA graph once and replayed.

    A signature is the (shape, dtype, device) of every positional argument.
    On the first call with a new signature:

    1. The arguments are cloned into static input buffers.
    2. ``f`` runs once on a side stream as warmup.
    3. ``f`` is captured into a ``torch.cuda.CUDAGraph``.

    Every call, including the first, copies the live arguments into the
    static inputs, replays the graph, and returns clones of the static
    outputs. The clones keep results from being overwritten by the next
    replay.

    Args:
        f: callable taking only positional tensor arguments and returning
            either a tensor or an iterable of tensors.

    Returns:
        A wrapper with the same positional call signature as ``f``. The
        wrapper returns a cloned tensor when ``f`` returns a tensor, and a
        list of cloned tensors otherwise.
    """
    # signature key -> (replay fn, graph, static inputs, static outputs).
    # The graph and static tensors are stored so they stay alive for as
    # long as the cache entry does.
    _graphs = {}

    @functools.wraps(f)
    def f_(*args):
        # Use the exact signature tuple as the key, not its hash(). Two
        # signatures with colliding hashes must not share a graph. Include
        # dtype and device, because copy_() would otherwise cast mismatched
        # inputs silently.
        key = tuple((tuple(a.shape), a.dtype, a.device) for a in args)
        cached = _graphs.get(key)
        if cached is not None:
            replay, *_ = cached
            return replay(*args)

        g = torch.cuda.CUDAGraph()
        in_tensors = [a.clone() for a in args]

        # Warm up on a side stream, following the PyTorch CUDA-graph docs,
        # so one-time work done by f runs before capture begins.
        side = torch.cuda.Stream()
        side.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(side):
            f(*in_tensors)
        torch.cuda.current_stream().wait_stream(side)

        with torch.cuda.graph(g):
            out_tensors = f(*in_tensors)

        # Iterating a bare tensor would split it along dim 0, or raise for
        # a 0-d tensor, so handle a single-tensor result separately.
        single_output = isinstance(out_tensors, torch.Tensor)

        def wrapped(*live_args):
            for static, live in zip(in_tensors, live_args):
                static.copy_(live)
            g.replay()
            if single_output:
                return out_tensors.clone()
            return [o.clone() for o in out_tensors]

        _graphs[key] = (wrapped, g, in_tensors, out_tensors)
        return wrapped(*args)

    return f_
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment