@zou3519
Created September 19, 2018 20:23
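import torch


@torch.jit.script
def fn(x):
    # Read the three dimensions as runtime ints, then view x with them in (T, B, C) order.
    B = x.size(0)
    C = x.size(1)
    T = x.size(2)
    return x.view(T, B, C)


x = torch.randn(3, 2, 1, dtype=torch.double, requires_grad=True)
out = fn(x)
# Print the optimized graph the JIT uses for inputs like x.
print(fn.graph_for(x))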
# Output:
# graph(%x : Double(*, *, *)) {
# %12 : int = prim::Constant[value=2]()
# %13 : int = prim::Constant[value=1]()
# %14 : int = prim::Constant[value=0]()
# %B : int = aten::size(%x, %14)
# %C : int = aten::size(%x, %13)
# %T : int = aten::size(%x, %12)
# %7 : int[] = prim::ListConstruct(%T, %B, %C)
# %8 : Dynamic = aten::view(%x, %7)
# return (%8);
# }
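# I think we should be able to say the output of the view is a Double(*, *, *) instead of Dynamic:
# the dtype is unchanged and the size list %7 has three elements, so the rank is known even though the sizes are not.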
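To make the argument concrete, here is a toy sketch of the rank-propagation rule being asked for. It is written in plain Python and is not the JIT's real shape-analysis pass. `TensorType` and `infer_view_type` are hypothetical names used only for illustration. The sketch assumes just two things: `aten::view` produces one output dimension per entry in its size list (this holds even for `-1` entries), and it preserves dtype.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass(frozen=True)
class TensorType:
    """Hypothetical stand-in for a JIT tensor type: a dtype plus a rank with unknown sizes."""
    dtype: str
    rank: int

    def __str__(self) -> str:
        # Render like the JIT's "Double(*, *, *)": one "*" per dimension of unknown size.
        return f"{self.dtype}({', '.join('*' for _ in range(self.rank))})"


def infer_view_type(inp: TensorType, sizes: List[Optional[int]]) -> TensorType:
    """Hypothetical rank-propagation rule for aten::view.

    The output has exactly len(sizes) dimensions, even when every entry is a
    runtime value (modeled as None) or -1. View never changes the dtype.
    """
    return TensorType(inp.dtype, len(sizes))


# Mirrors the graph above: %x is Double(*, *, *) and %7 = prim::ListConstruct(%T, %B, %C),
# where T, B and C are only known at run time (modeled as None).
x_type = TensorType("Double", 3)
out_type = infer_view_type(x_type, [None, None, None])
print(out_type)  # Double(*, *, *)
```

`prim::ListConstruct(%T, %B, %C)` always has length three. The rule therefore yields a rank-3 type even though none of the sizes is a constant.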