Skip to content

Instantly share code, notes, and snippets.

@mattjj
Created April 18, 2019 05:02
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save mattjj/2ba580930472e8e04c1759737268af92 to your computer and use it in GitHub Desktop.
Save mattjj/2ba580930472e8e04c1759737268af92 to your computer and use it in GitHub Desktop.
from jax.interpreters import ad
from jax.interpreters import partial_eval as pe
from jax import custom_transforms
from jax import core
from jax import grad
@custom_transforms
def f(x, y):
    """Compute ``x**2 + 3 * y``.

    The function is wrapped with ``custom_transforms``; elsewhere in this
    file the wrapped object is handed to ``defvjp_all``, which reads its
    ``primitive`` attribute.
    """
    squared_x = x**2
    tripled_y = 3 * y
    return squared_x + tripled_y
def defvjp_all(fun, vjp):
    """Route differentiation of ``fun`` through a user-supplied ``vjp``.

    Steps performed, in order:

    1. Save the JVP rule currently registered for ``fun.primitive``.
    2. Create a new primitive named ``'f_jvp'`` whose abstract evaluation
       defers to ``fun.primitive.abstract_eval`` on the primal inputs and
       whose implementation calls the saved JVP rule, keeping element
       ``[1]`` of its result.
    3. Replace the JVP rule of ``fun.primitive`` with one that evaluates
       ``fun(*xs)`` for the primal output and binds the new primitive on
       the packed primals and packed (zero-instantiated) tangents for the
       tangent output.
    4. Register a transpose rule for the new primitive that returns
       ``[None, core.pack(vjp(ct, *xs))]``.

    Args:
        fun: Callable object exposing a ``primitive`` attribute (in this
            file, a ``custom_transforms``-decorated function).
        vjp: Callable invoked as ``vjp(ct, *xs)``; its result is passed to
            ``core.pack``.

    Returns:
        None. The function works by mutating ``ad.primitive_jvps`` and
        ``ad.primitive_transposes``.
    """
    # TODO pytrees to jaxtupletrees
    saved_jvp = ad.primitive_jvps[fun.primitive]
    tangent_p = core.Primitive('f_jvp')

    def tangent_abstract_eval(xs, ts):
        # Only the primal inputs take part; ``ts`` is ignored here.
        return fun.primitive.abstract_eval(*xs)

    def tangent_impl(xs, ts):
        jvp_result = saved_jvp(tuple(xs), tuple(ts))
        # Element [1] is kept; presumably the tangent half of a
        # (primal, tangent) pair -- confirm against the saved rule.
        return jvp_result[1]

    def dummy_jvp(xs, ts):
        primal_out = fun(*xs)
        nonzero_ts = map(ad.instantiate_zeros, xs, ts)
        packed_primals = core.pack(xs)
        packed_tangents = core.pack(nonzero_ts)
        tangent_out = tangent_p.bind(packed_primals, packed_tangents)
        return primal_out, tangent_out

    def tangent_transpose(ct, xs, _):
        # ``None`` in the first slot, the packed result of ``vjp`` in the
        # second (the slot that corresponds to the tangent operand).
        user_cotangents = vjp(ct, *xs)
        return [None, core.pack(user_cotangents)]

    tangent_p.def_abstract_eval(tangent_abstract_eval)
    tangent_p.def_impl(tangent_impl)
    ad.primitive_jvps[fun.primitive] = dummy_jvp
    ad.primitive_transposes[tangent_p] = tangent_transpose
def custom_vjp(g, x, y):
    """Return the fixed list ``[99., 22.]`` regardless of the arguments.

    Args:
        g: Unused.
        x: Unused.
        y: Unused.

    Returns:
        A new two-element list ``[99., 22.]`` on every call.
    """
    fixed_result = [99., 22.]
    return fixed_result
# Register custom_vjp as the rule defvjp_all installs for f, then print the
# result of differentiating f at (3., 4.).
defvjp_all(f, custom_vjp)
# A bare `print expr` statement is a SyntaxError on Python 3; the call form
# with a single argument behaves the same on Python 2 and Python 3.
print(grad(f)(3., 4.))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment