@rom1504
Last active October 2, 2022 21:41
jax gpu setup
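# Create and activate a fresh virtual environment (Python 3.8)
python3.8 -m venv .env
source .env/bin/activate
pip install -U pip
# Install the JAX wheel built for CUDA 11 and cuDNN 8.2
pip install "jax[cuda11_cudnn82]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html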
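python
# In the Python REPL: check which backend and devices JAX picked up
import jax
jax.default_backend()
jax.devices()
import jax.numpy as jnp
# Run a small convolution to exercise the backend
jax.lax.conv(jnp.eye(4)[None, None], jnp.eye(4)[None, None], (1, 1), 'SAME')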
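
The interactive checks above can also be collected into a small script, which may be handy for re-running after a driver or wheel change. This is only a sketch: the function name `check_jax_backend` is made up for illustration, and it reports whatever backend JAX selects rather than insisting on a GPU.

```python
import jax
import jax.numpy as jnp


def check_jax_backend():
    """Report the active backend and run one small convolution on it."""
    backend = jax.default_backend()  # a string such as "cpu" or "gpu"
    devices = jax.devices()          # list of devices visible to JAX
    # Same check as the REPL session: a 4x4 identity "image" convolved
    # with a 4x4 identity kernel, both shaped (1, 1, 4, 4).
    x = jnp.eye(4)[None, None]
    out = jax.lax.conv(x, x, (1, 1), "SAME")
    out.block_until_ready()  # wait for the computation to actually finish
    return backend, devices, out


if __name__ == "__main__":
    backend, devices, out = check_jax_backend()
    print("backend:", backend)
    print("devices:", devices)
    print("conv output shape:", out.shape)
```

If the reported backend is `cpu` on a machine with a GPU, the CUDA wheel was probably not picked up by this environment.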