

@jvmncs
Last active May 13, 2022 15:47
eqx.filter_vmap raises TypeError when passed the in_axes kwarg
import equinox as eqx
import jax
import numpy as np
def func(x):
    return x + x
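# Reference behavior: plain jax.vmap, mapping over axis 0 of the input.
v_func_jax = jax.vmap(func, in_axes=0)
# The same request through Equinox; calling this raises the TypeError shown below.
v_func_eqx = eqx.filter_vmap(func, in_axes=0)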
x = np.array([[1, 2], [3, 4]], dtype=np.float64)
exp = np.array([[2, 4], [6, 8]], dtype=np.float64)
# jax.vmap produces the expected result.
y_j = v_func_jax(x)
np.testing.assert_array_equal(y_j, exp)
# eqx.filter_vmap fails instead:
y_e = v_func_eqx(x)
# Traceback (most recent call last):
#   File ".../equinox_inaxes.py", line 18, in <module>
#     y_e = v_func_eqx(x)
#   File ".../site-packages/equinox/vmap_pmap.py", line 146, in __call__
#     vmapd, nonvmapd = jax.vmap(
# TypeError: jax._src.api.vmap() got multiple values for keyword argument 'in_axes'
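The last frame of the traceback suggests the mechanism, though this is an inference and not confirmed against the Equinox source: the wrapper's `__call__` appears to invoke `jax.vmap` with an `in_axes` value it computes itself while also forwarding the user's keyword arguments, so `in_axes` reaches `jax.vmap` twice. The pure-Python sketch below reproduces that pattern under that assumption. `inner_vmap`, `buggy_filter_vmap` and `fixed_filter_vmap` are hypothetical stand-ins, not Equinox or JAX code, and `inner_vmap` does no batching at all; it only reports which `in_axes` it received.

```python
# Hypothetical stand-ins, NOT Equinox or JAX code.

def inner_vmap(fun, in_axes=0, out_axes=0):
    """Stand-in for jax.vmap: no batching, just reports the in_axes it got."""
    def mapped(*args):
        return fun(*args), in_axes
    return mapped


def buggy_filter_vmap(fun, **vmap_kwargs):
    # Assumed bug pattern: the wrapper passes its own in_axes AND forwards
    # every user keyword, so a user-supplied in_axes is passed twice.
    def call(*args):
        computed_in_axes = 0  # the wrapper's own choice
        return inner_vmap(fun, in_axes=computed_in_axes, **vmap_kwargs)(*args)
    return call


def fixed_filter_vmap(fun, **vmap_kwargs):
    # One way to avoid the collision: take in_axes out of the forwarded
    # kwargs (falling back to the wrapper's default) and pass it exactly once.
    in_axes = vmap_kwargs.pop("in_axes", 0)
    def call(*args):
        return inner_vmap(fun, in_axes=in_axes, **vmap_kwargs)(*args)
    return call


def func(x):
    return x + x


try:
    buggy_filter_vmap(func, in_axes=0)(3)
except TypeError as err:
    print(err)  # ...got multiple values for keyword argument 'in_axes'

print(fixed_filter_vmap(func, in_axes=1)(3))  # (6, 1)
```

As in the gist, the buggy wrapper is created without complaint and only fails when it is called, because the duplicate keyword is assembled inside `call`.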