# jax2torch: call a JAX function from PyTorch, with autograd support.
# From the gist https://gist.github.com/mattjj/e8b51074fed081d765d2f3ff90edf0e9 by @mattjj.

import torch
import torch.utils.dlpack
import jax
import jax.dlpack

# A generic mechanism for turning a JAX function into a PyTorch function.

def j2t(x_jax):
    # JAX -> PyTorch, typically zero-copy via the DLPack protocol.
    x_torch = torch.utils.dlpack.from_dlpack(jax.dlpack.to_dlpack(x_jax))
    return x_torch


def t2j(x_torch):
    # PyTorch -> JAX. Ensure contiguous memory before DLPack export;
    # see https://github.com/google/jax/issues/8082
    x_torch = x_torch.contiguous()
    x_jax = jax.dlpack.from_dlpack(torch.utils.dlpack.to_dlpack(x_torch))
    return x_jax
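
# A quick sanity check of the converters (an illustrative aside, not part of
# the original gist): the DLPack round trip preserves values, and on CPU it
# is typically zero-copy for contiguous tensors.
_x = torch.arange(4, dtype=torch.float32)
assert torch.equal(j2t(t2j(_x)), _x)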

def jax2torch(fun):
    class JaxFun(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            # Evaluate the JAX function and stash its VJP (reverse-mode
            # pullback) on the autograd context for use in backward.
            y_, ctx.fun_vjp = jax.vjp(fun, t2j(x))
            return j2t(y_)

        @staticmethod
        def backward(ctx, grad_y):
            # Apply the saved VJP to the incoming output cotangent.
            grad_x_, = ctx.fun_vjp(t2j(grad_y))
            return j2t(grad_x_),

    return JaxFun.apply
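
# For reference, a minimal standalone sketch (not from the original gist) of
# what jax.vjp provides: the primal output plus a function that maps an
# output cotangent to input cotangents, which is exactly the (forward,
# backward) pair torch.autograd.Function needs.
import jax.numpy as jnp

_y, _vjp = jax.vjp(jnp.sin, jnp.array(0.0))
assert _vjp(jnp.array(1.0))[0] == 1.0  # d/dx sin(x) at x=0 is cos(0) == 1.0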

# Here's a JAX function we want to interface with PyTorch code.
@jax.jit
def jax_square(x):
    return x ** 2

torch_square = jax2torch(jax_square)

# Let's run it on Torch data!
import numpy as np
x = torch.from_numpy(np.array([1., 2., 3.], dtype='float32'))
y = torch_square(x)
print(y) # tensor([1., 4., 9.])
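
# (An aside, not in the original gist: the wrapped function should agree
# exactly with a native PyTorch computation on the same data.)
print(torch.equal(y, x ** 2))  # True
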
# And differentiate!
x = torch.tensor(np.array([1., 2., 3.], dtype='float32'), requires_grad=True)
y = torch.sum(torch_square(x))
y.backward()
print(x.grad) # tensor([2., 4., 6.])
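
# As a cross-check (a sketch, not part of the original gist), this gradient
# matches what JAX computes directly: d/dx sum(x ** 2) = 2x.
expected = j2t(jax.grad(lambda v: jnp.sum(v ** 2))(t2j(x.detach())))
print(torch.allclose(x.grad, expected))  # True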