Skip to content

Instantly share code, notes, and snippets.

@ezyang
Created July 13, 2022 04:02
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save ezyang/df2d746cac3b2c7d55c181e37c57ef23 to your computer and use it in GitHub Desktop.
Save ezyang/df2d746cac3b2c7d55c181e37c57ef23 to your computer and use it in GitHub Desktop.
class ApplyCudaGraphs(torch.fx.Interpreter):
    """Interpreter that swaps every called submodule for a ``CudaGraphModule``.

    The graph is interpreted on fake tensors (see ``run``); each time a
    ``call_module`` node is reached, the target submodule is analysed for
    input mutations and replaced in ``self.module`` by a wrapper built from
    the original submodule and that analysis result.
    """

    def call_module(self, target, args, kwargs):
        """Wrap the submodule named ``target`` in place, then execute the node.

        Keyword arguments are not supported; the node must be called with
        positional arguments only.
        """
        assert not kwargs
        owner = self.module
        original = owner.get_submodule(target)
        # Remove the old entry first, then register the wrapper under the
        # same qualified name.
        owner.delete_submodule(target)
        mutated = FindInputMutations(original)(*args)
        wrapped = CudaGraphModule(original, mutated)
        owner.add_submodule(target, wrapped)
        # Dispatch through the base interpreter so the node is evaluated via
        # the freshly registered wrapper.
        return super().call_module(target, args, kwargs)

    def run(self, *args):
        """Interpret the graph with every argument converted by the fake mode."""
        with FakeTensorMode.push() as mode:
            fake_args = [mode.from_tensor(real) for real in args]
            return super().run(*fake_args)
class FindInputMutations(torch.fx.Interpreter):
    """Run a graph module and report which of its inputs are written in place.

    Inputs are identified by their placeholder position (0-based, in the
    order the placeholders are executed). An input counts as mutated when a
    ``call_function`` node receives, in an argument slot whose schema alias
    info is marked as a write, a tensor that shares storage with that input.

    Calling the instance runs the graph with the given arguments and returns
    the set of mutated input indices.
    """

    def __init__(self, gm, **kwargs):
        super().__init__(gm, **kwargs)
        # Maps a weak reference to a storage -> indices of every placeholder
        # backed by that storage (several inputs may alias one storage).
        self.inputs = defaultdict(set)
        # Position of the next placeholder to be executed.
        self.input_idx = 0
        # Indices of the inputs found to be written so far.
        self.mutated_inputs = set()

    def placeholder(self, target, args, kwargs):
        """Record the storage backing this input under its positional index."""
        r = super().placeholder(target, args, kwargs)
        self.inputs[StorageWeakRef(r.storage())].add(self.input_idx)
        self.input_idx += 1
        return r

    def call_function(self, target, args, kwargs):
        """Execute the node and note any graph inputs it writes to.

        ``target`` must expose a ``_schema`` attribute; targets without one
        raise ``AttributeError``.
        """
        r = super().call_function(target, args, kwargs)
        schema = target._schema
        for i, arg in enumerate(schema.arguments):
            # Resolve the runtime value bound to this schema argument:
            # positional first, then by keyword; skip it if it was omitted.
            if i < len(args):
                argument = args[i]
            elif arg.name in kwargs:
                argument = kwargs[arg.name]
            else:
                continue

            alias_info = arg.alias_info
            if not (alias_info and alias_info.is_write):
                continue

            if not isinstance(argument, torch.Tensor):
                # TODO: not correct for args that contain tensors in a struct
                # like list
                continue

            # Key on the storage of the argument that is written, not on the
            # node's result: the result need not be a tensor or alias it.
            key = StorageWeakRef(argument.storage())
            # ``.get`` avoids inserting an empty entry into the defaultdict
            # for storages that do not belong to any input.
            self.mutated_inputs.update(self.inputs.get(key, ()))
        # TODO: error on unrecognized nodes
        return r

    def __call__(self, *args):
        """Run the graph on ``args`` and return the set of mutated input indices."""
        self.run(*args)
        return self.mutated_inputs
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment