Skip to content

Instantly share code, notes, and snippets.

@sjdv1982
Last active June 8, 2023 12:17
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save sjdv1982/803695055c78b62e5d5dc92a004efa77 to your computer and use it in GitHub Desktop.
Save sjdv1982/803695055c78b62e5d5dc92a004efa77 to your computer and use it in GitHub Desktop.
import numpy as np
import jax
from jax import numpy as jnp
def _run_argsort_numpy(args) -> jnp.ndarray:
    """Compute argsort with NumPy; used as the host function for the callback.

    Args:
        args: A two-element ``(arr, axis)`` sequence. ``axis`` is forwarded
            unchanged to ``np.argsort``, so ``None`` sorts the flattened
            array.

    Returns:
        The NumPy sorting indices cast to ``int32``.
    """
    # Emit the trace message first, before touching the arguments.
    jax.debug.print("Run argsort using Numpy")
    values, sort_axis = args
    order = np.argsort(values, axis=sort_axis)
    return order.astype(np.int32)
def run_argsort_numpy(arr:jnp.ndarray, axis=None) -> jnp.ndarray:
    """Return the indices that would sort ``arr``.

    When the first JAX device is not a CPU, this simply defers to
    ``jnp.argsort``. On CPU the sort is delegated to NumPy on the host
    through ``jax.pure_callback``.

    Args:
        arr: Array to sort.
        axis: Axis to sort along; ``None`` sorts the flattened array.

    Returns:
        The sorting indices. On the callback path they are ``int32`` and
        have the flattened shape when ``axis`` is ``None``, otherwise the
        shape of ``arr``.
    """
    on_cpu = jax.devices()[0].device_kind == "cpu"
    if not on_cpu:
        return jnp.argsort(arr, axis=axis)

    # pure_callback must be told the output shape/dtype up front.
    out_dims = arr.ravel().shape if axis is None else arr.shape
    out_spec = jax.ShapeDtypeStruct(out_dims, np.int32)
    return jax.pure_callback(_run_argsort_numpy, out_spec, (arr, axis))
if jax.devices()[0].device_kind == "cpu":
    # On CPU the sort runs through jax.pure_callback, which JAX cannot
    # differentiate by itself; wrapping the function in custom_jvp supplies
    # the rule so that code using the indices can still be differentiated.
    run_argsort_numpy = jax.custom_jvp(run_argsort_numpy)

    @run_argsort_numpy.defjvp
    def default_grad(primals, tangents):
        """JVP rule for ``run_argsort_numpy``: indices carry no derivative.

        Args:
            primals: The ``(arr, axis)`` arguments of the primal call.
            tangents: Input tangents; unused, because the integer output
                does not vary smoothly with the input.

        Returns:
            ``(primal_out, tangent_out)`` where ``tangent_out`` is a zero
            array of dtype ``float0`` with the shape of ``primal_out``.
        """
        # The tangents are deliberately ignored: sorting them would be
        # meaningless, would cost a second host sort, and would pass the
        # tangent of ``axis`` where a sort axis is expected.
        del tangents
        primal_out = run_argsort_numpy(*primals)
        # float0 is JAX's tangent dtype for integer-valued outputs.
        tangent_out = np.zeros(primal_out.shape, dtype=jax.dtypes.float0)
        return primal_out, tangent_out
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment