Skip to content

Instantly share code, notes, and snippets.

@llandsmeer
Created October 8, 2023 09:38
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save llandsmeer/d11db31cbfe8a1cf6ddf44127aad8308 to your computer and use it in GitHub Desktop.
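Make jax.lax.cond work in jax2tf by lowering it to tf.case instead of tf.switch_case, as used below to export a diffrax ODE solve to ONNX.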
import numpy as np
import onnxruntime as rt
def patch_jax2tf_with_case_instead_of_switch_case():
    """Replace jax2tf's lowering of ``jax.lax.cond`` with a ``tf.case`` version.

    Registers a new implementation for ``jax.lax.cond_p`` in the global
    ``jax2tf.tf_impl`` table. Each branch is guarded by an explicit
    ``tf.equal(i, index)`` predicate and the branches are combined with
    ``tf.case``. As the function name indicates, this stands in for a
    ``tf.switch_case``-based lowering. NOTE(review): presumably the
    ``tf.case`` form is what makes the later ONNX export succeed; confirm.

    This mutates process-global jax2tf state. Call it before converting
    any function that uses ``lax.cond``.
    """
    from functools import partial

    import jax
    import tensorflow as tf
    from jax.experimental.jax2tf import jax2tf

    def _cond(index, *operands, branches, **params):
        # ``params`` absorbs the remaining cond_p parameters (e.g. ``linear``
        # on JAX versions that still pass it). None of them are needed here.
        del params
        # ``partial`` freezes ``jaxpr`` and ``i`` for each branch. A bare
        # lambda in this comprehension would late-bind them, so every branch
        # fn (called later, inside tf.case) would interpret the *last* jaxpr.
        pred_fn_pairs = [
            (
                tf.equal(i, index),
                partial(
                    jax2tf._interpret_jaxpr,
                    jaxpr,
                    *operands,
                    extra_name_stack=f"branch_{i}_fun",
                ),
            )
            for i, jaxpr in enumerate(branches)
        ]
        # exclusive=True makes tf.case raise if more than one predicate is
        # true. ``index`` selects a single branch, so exactly one should hold.
        return tf.case(pred_fn_pairs, exclusive=True)

    jax2tf.tf_impl[jax.lax.cond_p] = _cond
def build_onnx_model():
    """Build and return an ONNX model of a small diffrax ODE solve.

    The exported function integrates dy/dt = -y over t in [0, 1] with the
    explicit Euler solver and dt0=0.1, then returns ``ys[0]`` of the
    solution. NOTE(review): with diffrax's default saving this is
    presumably the state at t1; confirm.

    ``eqxi.to_onnx(simulate)`` returns a callable. Calling it with the
    example input ``1.0`` yields a pair, and the first element is the model
    that is returned here.
    """
    import equinox.internal as eqxi
    from diffrax import Euler, ODETerm, diffeqsolve

    def decay(t, y, args):
        # Exponential-decay vector field; ``t`` and ``args`` are unused.
        return -y

    def simulate(y0):
        sol = diffeqsolve(
            terms=ODETerm(decay),
            solver=Euler(),
            t0=0,
            t1=1,
            dt0=0.1,
            y0=y0,
        )
        return sol.ys[0]

    export = eqxi.to_onnx(simulate)
    model, _ = export(1.0)
    return model
# Install the tf.case-based cond lowering before tracing/exporting the model.
patch_jax2tf_with_case_instead_of_switch_case()
onnx_model = build_onnx_model()

# Request the CPU provider explicitly. Since onnxruntime 1.9, builds that
# enable additional execution providers (e.g. CUDA) raise an error unless
# ``providers`` is given.
sess = rt.InferenceSession(
    onnx_model.SerializeToString(), providers=["CPUExecutionProvider"]
)
input_name = sess.get_inputs()[0].name
initial_value = np.asarray(100.0, dtype=np.float32)
onnx_output = sess.run(None, {input_name: initial_value})[0]
print(onnx_output)

# The tolerance is loose because the model uses Euler steps of 0.1, while
# the reference is the exact solution 100 * e**-1 of dy/dt = -y at t=1.
# Raise explicitly instead of using `assert`, so the check survives `python -O`.
expected = 100 * np.exp(-1)
if not np.isclose(onnx_output, expected, rtol=0.1, atol=0.1):
    raise AssertionError(f"ONNX output {onnx_output} is not close to {expected}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment