Skip to content

Instantly share code, notes, and snippets.

@pierrelux
Last active August 13, 2019 22:41
Show Gist options
  • Save pierrelux/059bfef496354cd8fa0ff4557db3b58b to your computer and use it in GitHub Desktop.
import jax
from jax import jvp, grad
def f(x, y):
    """Example test function f(x, y) = x + y**2."""
    return x + y ** 2
def freeze(f, argnum, val):
    """Fix one argument of a two-argument function, returning a unary function.

    Args:
        f: Function of two positional arguments.
        argnum: Slot in which ``val`` is placed. ``0`` fixes the first
            argument (calls ``f(val, arg)``); any other value fixes the
            second argument (calls ``f(arg, val)``).
        val: Value to hold fixed.

    Returns:
        A function ``_f(arg)`` that calls ``f`` with ``val`` in slot
        ``argnum`` and ``arg`` in the remaining slot.
    """
    def _f(arg):
        args = (val, arg) if argnum == 0 else (arg, val)
        return f(*args)
    return _f
def mixed_jvp(f, order, primals, tangents):
    """Forward-mode JVP of one partial derivative of ``f`` along another argument.

    Takes the gradient of ``f`` with respect to argument ``order[0]`` and
    pushes ``tangents`` through it along argument ``order[1]``. Every other
    argument is held at its value in ``primals``. For ``order=(0, 1)`` and a
    unit tangent, the tangent output is the mixed second derivative
    d/dy (df/dx); ``order[0] == order[1]`` gives a pure second derivative.

    Args:
        f: Scalar-valued function of ``len(primals)`` positional arguments.
        order: Pair ``(i, j)``: differentiate w.r.t. argument ``i``, then take
            the JVP along argument ``j``.
        primals: Evaluation point, one value per positional argument of ``f``.
        tangents: One-element tuple holding the tangent for argument ``j``.

    Returns:
        ``(primal_out, tangent_out)`` as returned by ``jvp``: df/d(arg i) at
        ``primals`` and its directional derivative along argument ``j``.
    """
    wrt, along = order
    partial_f = grad(f, wrt)

    def partial_along(arg):
        # Replace only argument `along`; all other arguments stay at `primals`.
        # (The previous version froze primals[wrt] into slot `along`, which
        # swapped the arguments and differentiated along the wrong one.)
        args = list(primals)
        args[along] = arg
        return partial_f(*args)

    return jvp(partial_along, (primals[along],), tangents)
if __name__ == "__main__":
    # Demo: evaluate mixed_jvp for f at primals (2., 3.) with order (0, 1).
    # For f(x, y) = x + y**2, df/dx is identically 1, so the primal output
    # is 1.0 and the tangent output is 0.0.
    print(mixed_jvp(f, order=(0, 1), primals=(2., 3.), tangents=(1.,)))
@gehring
Copy link

gehring commented Aug 13, 2019

def make_mixed_jvp(f):
    """Build a forward-over-reverse mixed-derivative operator for ``f(x, y)``.

    Returns ``mixed_jvp(x, y)``. That function returns a function of
    ``tangents``, which computes the JVP of ``df/dx`` along ``y`` at
    ``(x, y)`` and returns ``(df/dx, tangent_out)``. With a unit tangent,
    ``tangent_out`` is the mixed second derivative d/dy (df/dx).

    ``y`` and ``tangents`` may be passed as bare values or as one-element
    tuples/lists (the convention ``jax.jvp`` itself uses).
    """
    grad_fun = jax.grad(f, 0)

    def _as_tuple(v):
        # jax.jvp requires primals/tangents as tuples or lists; wrap bare values.
        return tuple(v) if isinstance(v, (tuple, list)) else (v,)

    def mixed_jvp(x, y):
        return lambda tangents: jax.jvp(
            lambda y_: grad_fun(x, y_), _as_tuple(y), _as_tuple(tangents)
        )

    return mixed_jvp

@gehring
Copy link

gehring commented Aug 13, 2019

def make_mixed_jvp(f, yx=False):
    """Build a forward-over-reverse mixed-derivative operator for ``f(x, y)``.

    Returns ``mixed_jvp(a, b)``. That function returns a function of
    ``tangents``, which computes the JVP of the gradient w.r.t. the first
    argument along the second argument at ``(a, b)``. It returns
    ``(grad_value, tangent_out)``.

    If ``yx`` is False, ``(a, b) = (x, y)``: the gradient is df/dx and the
    JVP runs along ``y``. If ``yx`` is True, ``f``'s arguments are swapped
    first. The gradient is then taken w.r.t. ``f``'s second argument and the
    JVP runs along its first, so ``mixed_jvp`` must be called as
    ``mixed_jvp(y_value, x_value)``.

    ``b`` and ``tangents`` may be passed as bare values or as one-element
    tuples/lists (the convention ``jax.jvp`` itself uses).
    """
    func = f
    if yx:
        func = lambda x, y: f(y, x)
    grad_fun = jax.grad(func, 0)

    def _as_tuple(v):
        # jax.jvp requires primals/tangents as tuples or lists; wrap bare values.
        return tuple(v) if isinstance(v, (tuple, list)) else (v,)

    def mixed_jvp(x, y):
        return lambda tangents: jax.jvp(
            lambda y_: grad_fun(x, y_), _as_tuple(y), _as_tuple(tangents)
        )

    return mixed_jvp

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment