Ramith Hettiarachchi (ramithuh)

ramithuh / jax_default_device.py
Created May 1, 2023 23:49
Specify the JAX default device to use
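import timeit
import jax
import jax.numpy as jnp

# Make the first CPU device the default for new arrays and computations
jax.config.update("jax_default_device", jax.devices("cpu")[0])
# runs on CPU (float32 avoids int32 overflow in the dot product)
long_vector = jnp.arange(int(1e7), dtype=jnp.float32)
# block_until_ready() waits for JAX's async dispatch, so the timing is real
print(timeit.timeit(lambda: jnp.dot(long_vector, long_vector).block_until_ready(), number=10))
# Switch the default to the third GPU (requires at least three visible GPUs)
jax.config.update("jax_default_device", jax.devices("gpu")[2])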
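The snippet changes the default device for the rest of the process, and `jax.devices("gpu")[2]` fails on machines with fewer than three GPUs. A minimal sketch of a scoped alternative follows, assuming a recent JAX release that provides the `jax.default_device` context manager and the `Array.devices()` method. `pick_device` is a hypothetical helper, not part of JAX; it falls back to the first CPU device when the requested GPU is unavailable, so the sketch also runs on CPU-only machines.

```python
import jax
import jax.numpy as jnp


def pick_device(platform="gpu", index=0):
    """Hypothetical helper (not part of JAX): return jax.devices(platform)[index],
    falling back to the first CPU device if that device does not exist."""
    try:
        devices = jax.devices(platform)
    except (RuntimeError, ValueError):
        # jax.devices raises when no backend for `platform` is available
        devices = []
    if 0 <= index < len(devices):
        return devices[index]
    return jax.devices("cpu")[0]


# Ask for the third GPU, as in the gist; on smaller machines this falls back to CPU
target = pick_device("gpu", 2)

# Scoped alternative to jax.config.update("jax_default_device", ...):
# only arrays and computations created inside the block default to `target`
with jax.default_device(target):
    x = jnp.arange(4, dtype=jnp.float32)
    y = jnp.dot(x, x)  # 0 + 1 + 4 + 9 = 14.0

# Array.devices() reports the set of devices holding the result
print(float(y), y.devices())
```

The context manager keeps the change local, so code that runs later in the same process still uses the global default; `jax.config.update` is simpler when the whole script should run on one device.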