Skip to content

Instantly share code, notes, and snippets.

@leslie-fang-intel
Last active February 20, 2023 08:57
Show Gist options
  • Save leslie-fang-intel/60e4664eae40614ad0547aa7927b1be9 to your computer and use it in GitHub Desktop.
Save leslie-fang-intel/60e4664eae40614ad0547aa7927b1be9 to your computer and use it in GitHub Desktop.
import torch
import torch._dynamo as torchdynamo
import copy
class Mod(torch.nn.Module):
    """Toy module with shape-dependent control flow in ``forward``.

    ``forward`` branches on the size of dimension 0 of its input. This makes
    it a small example for observing which branch ``torchdynamo.export``
    captures and which guard it records for that branch.
    """

    def __init__(self) -> None:
        super().__init__()
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        """Return ``relu(x + x)`` if ``x.size(0) > 10``, else ``relu(x)``."""
        # Data-dependent branch on the leading dimension. When traced, only the
        # branch taken for the example input ends up in the graph, and the
        # size condition is recorded as a guard.
        if x.size(0) > 10:
            return self.relu(x + x)
        else:
            return self.relu(x)
if __name__ == "__main__":
    # Batch size 14 > 10, so Mod.forward takes the ``x + x`` branch.
    example_inputs = (torch.randn(14, 3, 224, 224),)
    m = Mod().eval()
    # Eager run of the module on the example inputs before capture.
    m(*example_inputs)

    # Program capture: export an ATen-level FX graph with symbolic shapes.
    # Inputs are deep-copied, presumably to keep the originals isolated from
    # tracing; confirm whether this is required.
    # NOTE(review): later PyTorch releases changed torch._dynamo.export's
    # calling convention (inputs passed to the returned callable); confirm
    # against the installed version.
    graph_module, guards = torchdynamo.export(
        m,
        *copy.deepcopy(example_inputs),
        aten_graph=True,
        tracing_mode="symbolic",
    )
    # Expected: the guard set contains 'x.size()[0] > 10', and the graph holds
    # only aten.add + aten.relu, i.e. just the branch taken for these inputs.
    print(guards)
    print(graph_module)
Printed output:
```
{Guard(name='self.relu', source=<GuardSource.LOCAL_NN_MODULE: 2>, create_fn=<function GuardBuilder.NN_MODULE at 0x7fdbee3d5160>, is_volatile=False, guard_types=None, code_list=None, obj_weakref=None, guarded_class_weakref=None), Guard(name='', source=<GuardSource.SHAPE_ENV: 6>, create_fn=<function GuardBuilder.SHAPE_ENV at 0x7fdbee3d5820>, is_volatile=False, guard_types=['SHAPE_ENV', 'SHAPE_ENV', 'SHAPE_ENV', 'SHAPE_ENV', 'SHAPE_ENV', 'SHAPE_ENV', 'SHAPE_ENV', 'SHAPE_ENV', 'SHAPE_ENV', 'SHAPE_ENV'], code_list=['x.size()[3] == x.size()[2]', 'x.stride()[0] == x.size()[2]**2*x.size()[1]', 'x.stride()[1] == x.size()[2]**2', 'x.stride()[2] == x.size()[2]', 'x.stride()[3] == 1', 'x.storage_offset() == 0', 'x.size()[0] > 10', 'x.size()[0] != 0 and x.size()[0] != 1', 'x.size()[1] != 0 and x.size()[1] != 1', 'x.size()[2] != 0 and x.size()[2] != 1'], obj_weakref=None, guarded_class_weakref=None), Guard(name='x', source=<GuardSource.LOCAL: 0>, create_fn=<function GuardBuilder.TENSOR_MATCH at 0x7fdbee3d58b0>, is_volatile=False, guard_types=['TENSOR_MATCH'], code_list=None, obj_weakref=<weakref at 0x7fdbed8a9bd0; to 'Tensor' at 0x7fdbeed63680>, guarded_class_weakref=<weakref at 0x7fdbf95874a0; to 'torch._C._TensorMeta' at 0x55eae9eb83a0 (Tensor)>), Guard(name='self', source=<GuardSource.LOCAL: 0>, create_fn=<function GuardBuilder.NN_MODULE at 0x7fdbee3d5160>, is_volatile=False, guard_types=['ID_MATCH'], code_list=['___check_obj_id(self, 140583230610208)'], obj_weakref=<weakref at 0x7fdbef2030e0; dead>, guarded_class_weakref=<weakref at 0x7fdc0e4d20e0; to 'type' at 0x55eaeb6878c0 (Mod)>)}
GraphModule()
def forward(self, orig_arg_0):
arg0, = fx_pytree.tree_flatten_spec(([orig_arg_0], {}), self._in_spec)
add_tensor = torch.ops.aten.add.Tensor(arg0, arg0); arg0 = None
relu_default = torch.ops.aten.relu.default(add_tensor); add_tensor = None
return pytree.tree_unflatten([relu_default], self._out_spec)
```
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment