Dynamic Shape support
The script to reproduce the issue:
```
import torch
import torch.nn as nn
import torch._dynamo as torchdynamo
import copy
class Mod(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(1024, 1000)

    def forward(self, x):
        x = torch.flatten(x, 1)
        return self.linear(x)

if __name__ == "__main__":
    x = torch.randn(16, 512, 2)
    example_inputs = (x,)
    model = Mod().eval()
    ref_result = model(*example_inputs)
    model, guards = torchdynamo.export(
        model,
        *copy.deepcopy(example_inputs),
        aten_graph=True,
        tracing_mode="real",
    )
    print("model after export is: {}".format(model), flush=True)
    model(torch.randn(2, 512, 2))
```
The model after `torchdynamo.export` is:
```
model after export is: GraphModule()
def forward(self, orig_arg_0):
    arg0, = fx_pytree.tree_flatten_spec(([orig_arg_0], {}), self._in_spec)
    view_default = torch.ops.aten.view.default(arg0, [16, 1024]);  arg0 = None
    _param_constant0 = self._param_constant0
    t_default = torch.ops.aten.t.default(_param_constant0);  _param_constant0 = None
    _param_constant1 = self._param_constant1
    addmm_default = torch.ops.aten.addmm.default(_param_constant1, view_default, t_default);  _param_constant1 = view_default = t_default = None
    return pytree.tree_unflatten([addmm_default], self._out_spec)
```
The export decomposes `flatten` into `torch.ops.aten.view.default(arg0, [16, 1024])` with the hard-coded size `[16, 1024]`. As a result, the exported model fails to run once the batch size changes (for example with the `torch.randn(2, 512, 2)` input above).
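A possible workaround is to trace with symbolic shapes instead of concrete ones, so the batch dimension is not baked into the `view`. Below is a minimal sketch, assuming the installed PyTorch build supports `tracing_mode="symbolic"` in `torchdynamo.export`; the exact output graph and guard behaviour may differ between versions, so treat it as illustrative rather than verified.

```
import copy

import torch
import torch.nn as nn
import torch._dynamo as torchdynamo

class Mod(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(1024, 1000)

    def forward(self, x):
        x = torch.flatten(x, 1)
        return self.linear(x)

if __name__ == "__main__":
    x = torch.randn(16, 512, 2)
    example_inputs = (x,)
    model = Mod().eval()

    # Assumption: tracing_mode="symbolic" is available in this version of
    # torchdynamo.export; it traces with symbolic shapes so the flatten is
    # not lowered to a view with the concrete size [16, 1024].
    exported_model, guards = torchdynamo.export(
        model,
        *copy.deepcopy(example_inputs),
        aten_graph=True,
        tracing_mode="symbolic",
    )
    print("model after symbolic export is: {}".format(exported_model), flush=True)

    # If the batch dimension stays symbolic, a different batch size should run.
    exported_model(torch.randn(2, 512, 2))
```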