Skip to content

Instantly share code, notes, and snippets.

@mattjj
Created October 26, 2021 07:19
Show Gist options
  • Save mattjj/473c5fc1f08ac704b26b6dce42a7682b to your computer and use it in GitHub Desktop.
Save mattjj/473c5fc1f08ac704b26b6dce42a7682b to your computer and use it in GitHub Desktop.
# References @Chillee's functorch NNC compiler: https://github.com/pytorch/functorch/blob/main/functorch/_src/nnc_compile.py
from typing import Callable, Dict, Any, List
from functools import partial
import numpy as np
import torch
import torch._C._te as te
from jax import core
from jax import linear_util as lu
from jax.interpreters import partial_eval as pe
from jax.interpreters import xla
from jax.tree_util import tree_flatten, tree_unflatten
from jax._src.api_util import flatten_fun
from jax._src.util import (unzip2, wraps, split_list, partition_list, safe_map,
safe_zip)
# Shadow the builtins with jax's length-checked, list-returning variants for
# the rest of this module; the original builtins stay reachable under the
# unsafe_* names. Each builtin is captured before its name is rebound.
unsafe_map = map
map = safe_map
unsafe_zip = zip
zip = safe_zip
def pytorch_jit(f: Callable):
    """Return a version of ``f`` that executes through the ``pytorch_jit`` primitive.

    On each call, positional and keyword arguments are flattened into a list
    of leaves with ``tree_flatten``. ``f`` is wrapped by ``flatten_fun`` so it
    consumes those leaves, and the flat call is bound to ``pytorch_jit_p``.
    The flat outputs are then rebuilt into ``f``'s output structure.

    The returned wrapper carries ``f``'s name, docstring and ``__wrapped__``
    (via ``wraps``), so it introspects like ``f``.
    """
    @wraps(f)
    def f_jit(*args, **kwargs):
        args_flat, in_tree = tree_flatten((args, kwargs))
        flat_f, out_tree = flatten_fun(lu.wrap_init(f), in_tree)
        out_flat = pytorch_jit_p.bind(flat_f, *args_flat)
        # out_tree is a thunk that is only called after bind has run flat_f;
        # presumably it is populated as a side effect of that run.
        return tree_unflatten(out_tree(), out_flat)
    return f_jit
pytorch_jit_p = core.CallPrimitive('pytorch_jit')


@pytorch_jit_p.def_impl
def pytorch_jit_impl(f: lu.WrappedFun, *args):
    """Trace ``f`` to a jaxpr, compile it with NNC's LLVM backend, and run it.

    Args:
      f: the flattened function bound to ``pytorch_jit_p``.
      *args: the flat array arguments.

    Returns:
      A list of numpy arrays, one per output of the traced jaxpr.

    Raises:
      NotImplementedError: if tracing produced constants (not supported yet).
    """
    # Trace: abstract the concrete inputs and stage ``f`` out to a jaxpr.
    avals_in = [xla.abstractify(a) for a in args]
    jaxpr, avals_out, consts = pe.trace_to_jaxpr_final(f, avals_in)
    if consts:
        raise NotImplementedError

    # Compile: one input buffer per argument, lower every equation, then
    # build, simplify and code-generate the loop nest.
    in_bufs = [parameter(idx, aval) for idx, aval in enumerate(avals_in)]
    out_bufs, body_stmts = nnc_eval_jaxpr(jaxpr, in_bufs)
    nest = te.LoopNest(te.Stmt(body_stmts), out_bufs)
    nest.simplify()
    nest.prepare_for_codegen()
    root = te.simplify(nest.root_stmt())
    codegen = te.construct_codegen('llvm', root, [*in_bufs, *out_bufs])

    # Execute: copy inputs into torch tensors, allocate outputs, and run the
    # kernel, which writes into the preallocated output tensors.
    torch_ins = [torch.from_numpy(np.array(a)) for a in args]
    torch_outs = [empty_like(aval) for aval in avals_out]
    codegen.call(torch_ins + torch_outs)
    return [np.asarray(t) for t in torch_outs]
def parameter(idx: int, aval: core.ShapedArray) -> te.BufHandle:
    """Create the NNC input buffer for the ``idx``-th argument.

    The buffer is named ``in_{idx}`` and takes its shape and dtype from
    ``aval``. It is a ``te.BufHandle`` (the previous ``te.ExprHandle``
    annotation did not match what is constructed).
    """
    return te.BufHandle(f'in_{idx}', shape_from_aval(aval), dtype_from_aval(aval))
def shape_from_aval(aval: core.ShapedArray) -> List[te.ExprHandle]:
    """Return one NNC integer expression per dimension of ``aval.shape``."""
    return [te.ExprHandle.int(extent) for extent in aval.shape]
# numpy dtype name -> NNC dtype; built once rather than on every call.
_TE_DTYPES = {
    'float32': te.Dtype.Float,
    'int32': te.Dtype.Int,
    'bool': te.Dtype.Bool,
}


def dtype_from_aval(aval: core.ShapedArray) -> te.Dtype:
    """Map ``aval.dtype`` to the corresponding NNC ``te.Dtype``.

    Raises:
      NotImplementedError: if the dtype is not float32, int32 or bool.
    """
    try:
        return _TE_DTYPES[aval.dtype.name]
    except KeyError:
        # Same exception kind as ``literal`` uses for unsupported dtypes.
        raise NotImplementedError(
            f'dtype_from_aval: unsupported dtype {aval.dtype}') from None
# numpy dtype name -> torch dtype; built once rather than on every call.
_TORCH_DTYPES = {
    'float32': torch.float32,
    'int32': torch.int32,
    'bool': torch.bool,
}


def empty_like(aval: core.ShapedArray) -> torch.Tensor:
    """Allocate an uninitialized torch tensor with ``aval``'s shape and dtype.

    Raises:
      NotImplementedError: if the dtype is not float32, int32 or bool.
    """
    try:
        dtype = _TORCH_DTYPES[aval.dtype.name]
    except KeyError:
        # Same exception kind as ``literal`` uses for unsupported dtypes.
        raise NotImplementedError(
            f'empty_like: unsupported dtype {aval.dtype}') from None
    return torch.empty(aval.shape, dtype=dtype)
def literal(aval: core.ShapedArray, val: Any) -> te.ExprHandle:
if aval.dtype == np.dtype('float32'):
return te.ExprHandle.float(val)
elif aval.dtype == np.dtype('int32'):
return te.ExprHandle.int(val)
elif aval.dtype == np.dtype('bool'):
return te.ExprHandle.bool(val)
else:
raise NotImplementedError(f'literal: {val}:{aval}')
def nnc_eval_jaxpr(jaxpr: core.Jaxpr, args):
    """Lower ``jaxpr`` to NNC by interpreting its equations in order.

    Args:
      jaxpr: the jaxpr to lower.
      args: one NNC handle per entry of ``jaxpr.invars``.

    Returns:
      ``(out_exprs, stmts)``. ``out_exprs`` holds the handles bound to
      ``jaxpr.outvars``. ``stmts`` holds the statements emitted by every
      equation, in program order.

    Raises:
      NotImplementedError: if an equation's primitive has no rule in
        ``translations``.
    """
    env: Dict[core.Var, te.ExprHandle] = {}
    stmts: List[te.Stmt] = []

    def read(x: core.Atom) -> te.ExprHandle:
        # Literals become constant expressions on demand; any other atom must
        # already be bound in env.
        if type(x) is core.Literal:
            return literal(x.aval, x.val)
        return env[x]

    # Explicit loops (not map) since these run purely for their side effect.
    for var, expr in zip(jaxpr.invars, args):
        env[var] = expr

    for eqn in jaxpr.eqns:
        in_avals = [x.aval for x in eqn.invars]
        out_avals = [v.aval for v in eqn.outvars]
        in_exprs = [read(x) for x in eqn.invars]
        rule = translations.get(eqn.primitive)
        if rule is None:
            raise NotImplementedError(
                f'nnc_eval_jaxpr: no NNC lowering for primitive {eqn.primitive}')
        out_exprs, out_stmts = rule(in_avals, out_avals, in_exprs, **eqn.params)
        stmts.extend(out_stmts)
        for var, expr in zip(eqn.outvars, out_exprs):
            env[var] = expr

    # Return ALL accumulated statements. Returning only the last equation's
    # statements would leave the buffers of earlier equations uncomputed, and
    # a jaxpr with no equations would hit an unbound name.
    return [read(v) for v in jaxpr.outvars], stmts
# Registry of NNC lowering rules, keyed by JAX primitive. nnc_eval_jaxpr
# looks up each equation's primitive here and calls the rule as
# rule(in_avals, out_avals, in_exprs, **eqn.params), unpacking the result
# into (out_exprs, out_stmts).
translations: Dict[core.Primitive, Callable] = {}
###
from jax._src.lax import lax
def standard_lowering(name: str):
    """Build a lowering rule that maps a primitive directly onto ``aten::<name>``.

    The returned rule takes ``(in_avals, out_avals, in_exprs)``. It ignores
    ``in_avals`` and requires exactly one output aval. It asks NNC's
    ``te.lower`` to emit the ATen op with that output's shape and dtype, and
    returns ``([output_buffer], [statement])``.
    """
    aten_name = f'aten::{name}'

    def lower(in_avals, out_avals, in_exprs):
        [out_aval] = out_avals  # single-output primitives only
        tensor = te.lower(aten_name, in_exprs,
                          shape_from_aval(out_aval), dtype_from_aval(out_aval))
        return [tensor.buf()], [tensor.stmt()]

    return lower
# Elementwise primitives lowered directly onto single ATen ops.
# NOTE(review): 'sin', 'mul' and 'cos' appear to match the ATen operator names
# (aten::sin / aten::mul / aten::cos); confirm te.lower has a lowering for each.
translations.update({
    lax.sin_p: standard_lowering('sin'),
    lax.mul_p: standard_lowering('mul'),
    lax.cos_p: standard_lowering('cos'),
})
###
from jax import grad
import jax.numpy as jnp
# Smoke test 1: sin through the NNC path, printed next to plain JAX's result.
x = jnp.array([1., 2., 3.])
sin_via_nnc = pytorch_jit(jnp.sin)
y = sin_via_nnc(x)
print(y)
print(jnp.sin(x))
# Smoke test 2: elementwise square (presumably a single mul equation).
x = jnp.array([1., 2., 3.])
y = pytorch_jit(lambda v: v * v)(x)
print(y)
print(x * x)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment