Create a gist now

Instantly share code, notes, and snippets.

Embed
What would you like to do?
import torch
@torch.jit.script
def foo(x: torch.Tensor, y: torch.Tensor):
    """Toy TorchScript loop whose branch condition is a tensor expression.

    Runs 100 iterations. On iterations whose index is divisible by 3,
    ``y`` is increased by ``x``; on all others, ``x`` is increased by ``y``.
    The loop index is first lifted into a tensor, so the compiled graph
    shows the int->tensor conversion, ``aten::fmod`` and the
    tensor->bool conversion that drives ``prim::If``.

    Note: ``+=`` on a tensor is an in-place op in eager PyTorch (and,
    presumably, in current TorchScript — confirm for your version), so the
    caller's ``x``/``y`` may be modified. Pass clones to keep them intact.

    Args:
        x: first accumulator tensor.
        y: second accumulator tensor.

    Returns:
        The final ``(x, y)`` pair.
    """
    for i in range(100):
        # `_to_tensor` was an early TorchScript-only builtin that no longer
        # exists; `torch.tensor(i)` is the supported way to lift an int.
        if torch.fmod(torch.tensor(i), 3) == 0:
            y += x
        else:
            x += y
    return x, y
# A scripted free function is a ScriptFunction, which exposes its IR directly
# as `.graph`; it has no `forward` attribute, unlike a ScriptModule.
print(foo.graph)
====
graph(%x.1 : Dynamic
%y.1 : Dynamic) {
%2 : int = prim::Constant[value=100]()
%3 : int = prim::Constant[value=1]()
%y : Dynamic, %x : Dynamic = prim::Loop(%2, %3, %y.1, %x.1)
block0(%i : int, %11 : Dynamic, %12 : Dynamic) {
%5 : Long() = prim::NumToTensor(%i)
%6 : int = prim::Constant[value=3]()
%7 : Dynamic = aten::fmod(%5, %6)
%8 : int = prim::Constant[value=0]()
%9 : Dynamic = aten::eq(%7, %8)
%10 : int = prim::TensorToNum(%9)
%x.3 : Dynamic, %y.3 : Dynamic = prim::If(%10)
block0() {
%13 : int = prim::Constant[value=1]()
%y.2 : Dynamic = aten::add(%11, %12, %13)
-> (%12, %y.2)
}
block1() {
%15 : int = prim::Constant[value=1]()
%x.2 : Dynamic = aten::add(%12, %11, %15)
-> (%x.2, %11)
}
%19 : int = prim::Constant[value=1]()
-> (%19, %y.3, %x.3)
}
return (%x, %y);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment