# Gist by @mattjj (mattjj/bc31f76f9e97f6de03114ea10cb853f7),
# created March 29, 2020 16:45.
# Tutorial: defining a custom JAX primitive ("sincos", i.e. sin(cos(x))) and
# attaching evaluation, abstract-evaluation, XLA translation, and
# differentiation rules to it.
from jax import core
# A primitive is just a name to which we associate rules. Every rule defined
# below (impl, abstract eval, XLA translation, JVP) is attached to this one
# object, and 'sincos' is the name that shows up in printed jaxprs
# (e.g. "b = sincos a").
sincos_p = core.Primitive('sincos')
# A primitive's "bind" is how it gets applied, in a way that interacts with the
# trace/transform machinery. As a convention we wrap them in Python functions
# like this:
def sincos(x):
    """Apply the ``sincos`` primitive, i.e. compute sin(cos(x)).

    Routing the computation through ``sincos_p.bind`` (instead of computing
    it directly) is what lets JAX's trace/transform machinery intercept the
    call and dispatch to whichever rule applies.

    Args:
      x: input value.

    Returns:
      The result of binding ``sincos_p`` on ``x``.
    """
    return sincos_p.bind(x)
# We can't do anything before we attach rules. Even evaluation is a rule. Here's
# how we attach an evaluation rule.
import numpy as onp
def sincos_impl(x):
    """Concrete evaluation rule for ``sincos_p``: sin(cos(x)) using NumPy.

    Args:
      x: scalar or array-like value; NumPy's ufuncs apply elementwise.

    Returns:
      ``onp.sin(onp.cos(x))``.
    """
    return onp.sin(onp.cos(x))
# Register sincos_impl as the evaluation rule, so that calling sincos on a
# concrete value (as below) runs the NumPy code. A later def_impl call in this
# file swaps in an XLA-based impl instead.
sincos_p.def_impl(sincos_impl)
# Now we can evaluate it:
print(sincos(3.)) # -0.8360218615377305
# For making jaxprs and jit compilation (and a few other transforms) we need an
# abstract evaluation rule. An abstract evaluation rule must return an upper
# bound on the abstract value lattice for the output given the input. Here's a
# verbose way of doing it:
from jax.core import UnshapedArray, ShapedArray, ConcreteArray
def sincos_abstract_eval(x):
    """Verbose abstract evaluation rule for ``sincos_p``.

    Maps the abstract value of the input to the abstract value of the
    output. sin(cos(x)) is elementwise, so the output keeps the input's
    shape and dtype, at the same level of refinement as the input
    (concrete -> shaped -> unshaped).

    Args:
      x: abstract value of the input.

    Returns:
      An abstract value of the same kind as ``x``. For a ``ConcreteArray``,
      the concrete result is computed eagerly with ``sincos_impl``.

    Raises:
      TypeError: if ``x`` does not have a floating dtype, or is not one of
        the handled abstract-value types.
    """
    if not onp.issubdtype(x.dtype, onp.floating):
        raise TypeError("must be floating dtype")
    # Check from most to least specific, so a refined value is not widened.
    # This order matters if, as in JAX's core, ConcreteArray subclasses
    # ShapedArray, which subclasses UnshapedArray (assumption; confirm).
    if isinstance(x, ConcreteArray):
        return ConcreteArray(sincos_impl(x.val))
    elif isinstance(x, ShapedArray):
        return ShapedArray(x.shape, x.dtype)
    elif isinstance(x, UnshapedArray):
        return UnshapedArray(x.dtype)
    else:
        raise TypeError(
            "unsupported abstract value for sincos: {!r}".format(x))
# Register the verbose rule. The compact sincos_abstract_eval defined next
# is registered on the same primitive afterwards and supersedes this one.
sincos_p.def_abstract_eval(sincos_abstract_eval)
# But here's a quicker way that will work just fine.
from jax.core import raise_to_shaped
def sincos_abstract_eval(x):
    """Abstract evaluation rule for ``sincos_p`` (compact version).

    sin(cos(x)) is elementwise, so the output has the input's shape and
    dtype. ``raise_to_shaped`` lifts ``x`` to the shaped level of the
    abstract-value lattice; presumably any concrete value is dropped, which
    is still a valid upper bound for the output.

    Args:
      x: abstract value of the input.

    Returns:
      ``raise_to_shaped(x)``.

    Raises:
      TypeError: if ``x`` does not have a floating dtype.
    """
    if not onp.issubdtype(x.dtype, onp.floating):
        raise TypeError("must be floating dtype")
    return raise_to_shaped(x)
# Register the compact rule on the same primitive. This second registration
# is presumably the one in effect for the rest of the file.
sincos_p.def_abstract_eval(sincos_abstract_eval)
# Now we can make jaxprs:
from jax import make_jaxpr
# make_jaxpr traces sincos with an abstract input, which is why an abstract
# evaluation rule had to be registered first. Expected output:
print(make_jaxpr(sincos)(3.))
# { lambda ; a.
#   let b = sincos a
#   in (b,) }
# For jit compilation we also need an XLA translation rule.
from jax.interpreters import xla
def sincos_translation_rule(c, x):
    """XLA translation rule for ``sincos_p``: emit Sin(Cos(x)).

    Args:
      c: XLA ComputationBuilder used to build the lowered computation.
      x: XlaOp representing the input.

    Returns:
      The XlaOp computing sin(cos(x)).
    """
    # c is an XLA ComputationBuilder, x is an XlaOp representing the input
    return c.Sin(c.Cos(x))
# Register the translation rule in xla.translations, keyed by the primitive;
# this is what makes jit-compiling code that calls sincos possible.
xla.translations[sincos_p] = sincos_translation_rule
# Now we can jit:
from jax import jit
# Compare the jit-compiled path (which uses the XLA translation rule) with
# eager evaluation (which still uses the NumPy impl at this point). The jitted
# result prints fewer digits: presumably float32, JAX's default when 64-bit
# mode is off, whereas the NumPy impl computes in float64.
a = jit(lambda x: sincos(sincos(x)))(3.)
b = sincos(sincos(3.))
print(a) # 0.62131506
print(b) # 0.6213150041315272 <-- numpy impl has more bits
# A trick we can pull is to generate an impl from the translation rule,
# basically meaning "when you want to evaluate, just jit compile the translation
# rule by itself." Then we don't need an onp-based impl. Here's how that looks:
from functools import partial
# Replace the NumPy impl. Eager evaluation now compiles the translation rule
# for this single primitive via xla.apply_primitive, so eager results match
# the jitted precision shown above.
sincos_p.def_impl(partial(xla.apply_primitive, sincos_p))
print(sincos(3.)) # -0.83602184
print(sincos(sincos(3.))) # 0.62131506
# Finally, differentiation rules! Here's a forward-mode rule:
from jax.interpreters import ad
from jax import lax # most of our primitives live here
def sincos_jvp_rule(primals, tangents):
    """Forward-mode (JVP) rule for ``sincos_p``.

    By the chain rule, d/dx sin(cos(x)) = cos(cos(x)) * (-sin(x)). The output
    tangent is that derivative times the input tangent ``t``. The result is
    linear in ``t``, which is what lets JAX transpose this rule to get
    reverse mode.

    Args:
      primals: 1-tuple ``(x,)`` of primal inputs.
      tangents: 1-tuple ``(t,)`` of input tangents.

    Returns:
      ``(sincos(x), t * (-sin(x)) * cos(cos(x)))``.
    """
    x, = primals
    t, = tangents
    # Recurse through the primitive so the primal output is also traceable.
    out_primal = sincos(x)
    out_tangent = t * (-lax.sin(x)) * lax.cos(lax.cos(x))
    return out_primal, out_tangent
# Register the JVP rule in ad.primitive_jvps, keyed by the primitive.
ad.primitive_jvps[sincos_p] = sincos_jvp_rule
# Now we can use forward-mode autodiff:
from jax import jvp
# jvp with a unit input tangent (1.) returns the primal output and the
# derivative at x = 3.
y, y_dot = jvp(sincos, (3.,), (1.,))
print(y) # -0.83602184
print(y_dot) # -0.07743199
# Cross-check against the same function built from lax's own primitives:
# both the primal and tangent outputs agree with the custom rule above.
y, y_dot = jvp(lambda x: lax.sin(lax.cos(x)), (3.,), (1.,))
print(y) # -0.83602184
print(y_dot) # -0.07743199
# We can use reverse-mode too, since JAX automatically transposes our
# forward-mode rule to generate reverse mode:
from jax import grad
# No VJP rule has been registered at this point: grad is derived by
# transposing the (linear-in-tangent) JVP rule.
print(grad(sincos)(3.)) # -0.07743199
# Custom reverse-mode rules are a bit trickier, since JAX doesn't implement
# reverse-mode directly. We don't do this for any JAX primitives in the core.
# Moreover, we can't set up a custom VJP rule *and* still keep forward-mode
# working.
# The mechanism is being replaced, but here's the not-quite-yet-deprecated way
# of doing it (see https://github.com/google/jax/pull/636):
# Register a custom VJP rule via ad.defvjp2. The lambda presumably receives
# the output cotangent g, the primal output ans (unused here), and the input
# x. It returns the input cotangent g * cos(cos(x)) * (-sin(x)), which is the
# same derivative as the JVP rule.
# NOTE(review): ad.defvjp2 is the interim mechanism that the comment above
# says is being replaced; newer JAX releases expose jax.custom_vjp instead.
# Confirm against the installed JAX version.
ad.defvjp2(sincos_p, lambda g, ans, x: g * lax.cos(lax.cos(x)) * (-lax.sin(x)))
print(grad(sincos)(3.)) # -0.07743199
# (End of gist.)