@thecharlieblake
Last active October 13, 2023 11:39
Given a numpy function, prints equivalent PyTorch code (as canonical ATen ops) and returns it as a new function.
from typing import Callable

import numpy as np
import torch
from torch._dynamo.backends.common import aot_autograd
from torch.fx.graph_module import GraphModule

# NOTE: requires torch >= 2.1.0


def np2torch(fn: Callable) -> Callable:
    """
    Given a numpy function, prints equivalent PyTorch code
    (as canonical ATen ops) and returns it as a new function.
    """

    def aot_compile_backend(gm: GraphModule, _) -> Callable:
        # aot_autograd hands the forward graph here after decomposition
        # to canonical ATen ops; print it and return it unchanged.
        print(gm.code)
        return gm

    torch._dynamo.reset()
    compile_backend = aot_autograd(fw_compiler=aot_compile_backend)
    return torch.compile(fn, backend=compile_backend)


def example_fn(a, b):
    c = a + b
    d = np.tan(np.matmul(a, b))
    e = c - d
    return np.sum(e, axis=-1)


a, b = np.random.randn(2**10, 2**10), np.random.randn(2**10, 2**10)
print("Numpy:", example_fn(a, b))

# The compiled function expects torch tensor inputs.
torch_fn = np2torch(example_fn)
a, b = torch.from_numpy(a), torch.from_numpy(b)
print("Torch:", torch_fn(a, b))
thecharlieblake commented Oct 13, 2023

Output:

Numpy: [-5.46059881e+04 -5.20969791e+02  4.14132323e+02 ... -2.61937142e+02
 -3.53576991e+01  1.46732928e+02]

def forward(self, arg0_1, arg1_1):
    add = torch.ops.aten.add.Tensor(arg0_1, arg1_1)
    mm = torch.ops.aten.mm.default(arg0_1, arg1_1);  arg0_1 = arg1_1 = None
    tan = torch.ops.aten.tan.default(mm);  mm = None
    sub = torch.ops.aten.sub.Tensor(add, tan);  add = tan = None
    sum_1 = torch.ops.aten.sum.dim_IntList(sub, [1]);  sub = None
    return (sum_1,)
    
Torch: [-5.46059881e+04 -5.20969791e+02  4.14132323e+02 ... -2.61937142e+02
 -3.53576991e+01  1.46732928e+02]
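
A note from me rather than the gist: each torch.ops.aten.* line in that printout is an ordinary callable on PyTorch's op registry, so the graph can be checked op by op:

# Quick sanity check (my addition, assuming torch >= 2.1): the printed
# aten ops run directly and match their torch-level equivalents.
import torch

x = torch.randn(4, 4, dtype=torch.float64)
y = torch.ops.aten.tan.default(torch.ops.aten.mm.default(x, x))
assert torch.allclose(y, torch.tan(x @ x))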

thecharlieblake commented Oct 13, 2023

Just to prove this is indeed valid PyTorch, if you then run forward(None, a, b) you get:

(tensor([-5.4606e+04, -5.2097e+02,  4.1413e+02,  ..., -2.6194e+02,
         -3.5358e+01,  1.4673e+02], dtype=torch.float64),)

Magic!
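
Spelled out, that check is just the following (a sketch pasting the printed graph back in; the first parameter stands in for self, hence the leading None):

import torch

def forward(self, arg0_1, arg1_1):
    add = torch.ops.aten.add.Tensor(arg0_1, arg1_1)
    mm = torch.ops.aten.mm.default(arg0_1, arg1_1);  arg0_1 = arg1_1 = None
    tan = torch.ops.aten.tan.default(mm);  mm = None
    sub = torch.ops.aten.sub.Tensor(add, tan);  add = tan = None
    sum_1 = torch.ops.aten.sum.dim_IntList(sub, [1]);  sub = None
    return (sum_1,)

print(forward(None, a, b))  # a, b are the torch tensors from the script above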
