@ramithuh
Created May 1, 2023 23:49
Specify the default device for JAX to use
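import jax
import jax.numpy as jnp
# Make the first CPU device the default for computations whose inputs
# have not been explicitly placed (committed) on a specific device.
jax.config.update("jax_default_device", jax.devices("cpu")[0])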
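To confirm that the setting took effect, one can check where a freshly created array lives. A minimal sketch, assuming a JAX 0.4+ release where arrays expose `.devices()` (it returns the set of devices holding the array's data); the small array `x` is only for illustration:

```python
import jax
import jax.numpy as jnp

# Use the first CPU device as the default, as in the snippet above.
cpu = jax.devices("cpu")[0]
jax.config.update("jax_default_device", cpu)

# A new, uncommitted array is allocated on the default device.
x = jnp.arange(10)

# .devices() returns a set of Device objects; here it should contain only `cpu`.
print(x.devices())
```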
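# Runs on the CPU. float32 is used because an integer arange would
# overflow when the dot product sums the squares of 1e7 values.
long_vector = jnp.arange(int(1e7), dtype=jnp.float32)
# %timeit is an IPython/Jupyter magic; block_until_ready() waits for JAX's
# asynchronous dispatch to finish so the actual compute time is measured.
%timeit jnp.dot(long_vector, long_vector).block_until_ready()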
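Because `%timeit` only works inside IPython or Jupyter, a plain Python script needs another way to take the same measurement. The sketch below uses the standard-library `timeit` module as a rough stand-in; the helper name `run` and the repeat counts are arbitrary choices, and the warm-up call keeps one-time compilation out of the timed region:

```python
import timeit

import jax
import jax.numpy as jnp

# Same setup as above: run uncommitted computations on the first CPU device.
jax.config.update("jax_default_device", jax.devices("cpu")[0])

# float32 keeps the sum of squares within range (integers would overflow).
long_vector = jnp.arange(int(1e7), dtype=jnp.float32)


def run():
    # block_until_ready() blocks until the asynchronous computation finishes,
    # so the timer measures the work itself rather than just the dispatch.
    return jnp.dot(long_vector, long_vector).block_until_ready()


run()  # warm-up: compile and allocate before timing

# 5 rounds of 10 calls each; report the best round's per-call time.
times = timeit.repeat(run, number=10, repeat=5)
best_per_call = min(times) / 10
print(f"best of 5: {best_per_call * 1e3:.3f} ms per dot product")
```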
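# Switch the default to the GPU at index 2 (the third GPU); this raises
# an IndexError on machines with fewer than three visible GPUs.
jax.config.update("jax_default_device", jax.devices("gpu")[2])
# Runs on the GPU: recreate the array so it is allocated on the new default.
long_vector = jnp.arange(int(1e7), dtype=jnp.float32)
%timeit jnp.dot(long_vector, long_vector).block_until_ready()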
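`jax.config.update` changes the default for the rest of the process. JAX also provides `jax.default_device` as a context manager that applies the same setting only inside a `with` block and restores the previous value afterwards. A sketch under the assumption that one wants to fall back to the CPU when the requested GPU is missing; the helper `pick_device` is hypothetical. It catches `RuntimeError` (raised by `jax.devices("gpu")` when no GPU backend is available) and `IndexError` (too few devices):

```python
import jax
import jax.numpy as jnp


def pick_device(platform="gpu", index=0):
    """Return jax.devices(platform)[index], or the first CPU device if unavailable."""
    try:
        return jax.devices(platform)[index]
    except (RuntimeError, IndexError):
        # RuntimeError: no backend for `platform`; IndexError: not enough devices.
        return jax.devices("cpu")[0]


# Request the third GPU, as in the snippet above, but degrade gracefully.
target = pick_device("gpu", 2)

# The default device applies only inside this block.
with jax.default_device(target):
    long_vector = jnp.arange(int(1e7), dtype=jnp.float32)
    result = jnp.dot(long_vector, long_vector).block_until_ready()

print(target, long_vector.devices())
```

The default device only governs uncommitted data. An array placed explicitly with `jax.device_put(x, device)` is committed to that device, and computations on it run there regardless of this setting.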