Skip to content

Instantly share code, notes, and snippets.

@mattjj
Created March 12, 2020 17:23
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save mattjj/561dc690ef3cec9111f699c82ce2082b to your computer and use it in GitHub Desktop.
Save mattjj/561dc690ef3cec9111f699c82ce2082b to your computer and use it in GitHub Desktop.
from functools import partial
from jax import core
from jax.util import safe_map, safe_zip
import jax.linear_util as lu
# Shadow the builtins with the jax.util "safe" variants for the rest of this
# module (NOTE(review): these presumably check that argument lengths match and
# return lists rather than iterators — confirm against jax.util).
map, zip = safe_map, safe_zip
@lu.transformation
def _rewrite(rules, args):
    """Run a flat function under a fresh ``RewriteTrace``.

    This is a linear_util generator transformation. ``rewrite`` applies it as
    ``_rewrite(flat_fun, rules)``.

    Args:
      rules: mapping from primitive to replacement callable. It is stored on
        the trace master so ``RewriteTrace.process_primitive`` can consult it.
      args: flat list of input values for the wrapped function.

    Yields:
      First, ``(in_tracers, {})``: the positional and keyword arguments for the
      wrapped function. Its outputs are sent back into the generator.
      Then, the list of plain output values, with the tracers unwrapped.
    """
    with core.new_master(RewriteTrace) as master:
        # Make the rules reachable from every RewriteTrace built on this master.
        master.rules = rules
        trace = RewriteTrace(master, core.cur_sublevel())
        # Box each input so primitives applied to it dispatch to this trace.
        in_tracers = map(partial(RewriteTracer, trace), args)
        out_tracers = yield in_tracers, {}
        # full_raise makes every output a tracer of this trace, even one that
        # did not depend on the inputs, so ``.val`` can be read uniformly.
        outs = [trace.full_raise(t).val for t in out_tracers]
        # Drop local references before the master context exits.
        # NOTE(review): presumably required by JAX's leaked-trace check on
        # exit from new_master — confirm against the JAX version in use.
        del master, out_tracers
    yield outs
class RewriteTracer(core.Tracer):
    """Tracer that boxes the concrete value computed for a traced variable.

    Attributes:
      _trace: the owning ``RewriteTrace``.
      val: the wrapped value; the abstract value is derived from it on demand.
    """

    __slots__ = ["_trace", "val"]

    def __init__(self, trace, val):
        self._trace, self.val = trace, val

    aval = property(
        lambda self: core.get_aval(self.val),
        doc="Abstract value of the wrapped ``val``.",
    )

    def full_lower(self):
        # Never unwraps: the tracer itself is handed back, not ``val``.
        return self
class RewriteTrace(core.Trace):
    """Trace that evaluates each primitive eagerly on the wrapped values.

    If a rule is registered for a primitive, it is called instead of the
    primitive. Rules live in ``self.master.rules``; ``_rewrite`` stores them
    there when it creates the master. Call and map primitives are not
    supported yet.
    """

    def pure(self, val):
        # Wrap a value that does not belong to any trace.
        return RewriteTracer(self, val)

    def lift(self, val):
        # Wrap a value coming from a lower-level trace
        # (NOTE(review): per JAX's Trace protocol of this era — confirm).
        return RewriteTracer(self, val)

    def sublift(self, val):
        # ``val`` is a RewriteTracer from another sublevel; re-box its value.
        return RewriteTracer(self, val.val)

    def process_primitive(self, primitive, tracers, params):
        """Apply ``primitive`` (or its rewrite rule) to the tracers' values.

        A rule is called as ``rule(*values, **params)`` and must return what
        the primitive would: a sequence when ``primitive.multiple_results``,
        otherwise a single value.
        """
        vals_in = [t.val for t in tracers]
        # Bug fix: the original looked up a bare name ``rules``, which is not
        # in scope here, so every registered rule raised NameError.
        rules = self.master.rules
        if primitive in rules:
            vals_out = rules[primitive](*vals_in, **params)
        else:
            vals_out = primitive.bind(*vals_in, **params)
        if primitive.multiple_results:
            return map(partial(RewriteTracer, self), vals_out)
        return RewriteTracer(self, vals_out)

    def process_call(self, call_primitive, f, tracers, params):
        # Raise explicitly: ``assert False`` is stripped under ``python -O``
        # and would make this silently return None.
        raise NotImplementedError(
            f"rewrite does not yet support call primitive {call_primitive}")

    def process_map(self, map_primitive, f, tracers, params):
        raise NotImplementedError(
            f"rewrite does not yet support map primitive {map_primitive}")
### api.py
from jax.api_util import flatten_fun_nokwargs
from jax.tree_util import tree_flatten, tree_unflatten
def rewrite(fun, *args, rules):
    """Evaluate ``fun(*args)`` with selected primitives replaced by rules.

    Args:
      fun: Python callable; its arguments and results may be pytrees.
      *args: positional arguments passed to ``fun``.
      rules: dict mapping a primitive to a replacement callable. The callable
        receives the primitive's input values positionally and its params as
        keyword arguments.

    Returns:
      The result of ``fun``, with the rewrites applied, in its original pytree
      structure.
    """
    leaves, treedef = tree_flatten(args)
    flat_fun, get_out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), treedef)
    transformed = _rewrite(flat_fun, rules)
    out_leaves = transformed.call_wrapped(leaves)
    return tree_unflatten(get_out_tree(), out_leaves)
@mattjj
Copy link
Author

mattjj commented Mar 12, 2020

Usage example:

from jax.lax.lax import mul_p

rules = {
    mul_p : lambda x, y: 3 * y,
}

out = rewrite(lambda x: 2 * x, 2, rules=rules)
print(out)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment