@redjade
Last active January 26, 2024 14:50
jax setup and test
'''
install with gpu support
see https://github.com/google/jax#pip-installation
PYTHON_VERSION=cp37 # alternatives: cp35, cp36, cp37, cp38
CUDA_VERSION=cuda100 # alternatives: cuda90, cuda92, cuda100, cuda101 (check with nvcc --version)
PLATFORM=linux_x86_64 # alternatives: linux_x86_64
BASE_URL='https://storage.googleapis.com/jax-releases'
# check the latest version of jaxlib on https://github.com/google/jax/releases
pip install --upgrade $BASE_URL/$CUDA_VERSION/jaxlib-0.1.37-$PYTHON_VERSION-none-$PLATFORM.whl
pip install --upgrade jax # install jax
'''
import os
# don't preallocate GPU memory
# see https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
# force the CPU backend even when a GPU-enabled jaxlib is installed
# XXX the process still reserves a minimal GPU memory footprint; no way found yet to turn that off
# ref: https://github.com/google/jax/blob/master/jax/lib/xla_bridge.py
os.environ['JAX_PLATFORM_NAME']='cpu'
import jax.numpy as np
import numpy as onp
from jax import grad, jit, vmap, random
x = np.ones((5000, 5000))
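
The script above stops after allocating an array, so it does not yet exercise the imported `grad`, `jit`, `vmap`, or `random`. The following is a minimal smoke-test sketch, not part of the original gist: `sum_of_squares`, `batch`, and `batched_grad` are hypothetical names chosen for illustration, and the block assumes the environment variables above were set before `jax` was first imported.

```python
import jax
import jax.numpy as np
import numpy as onp
from jax import grad, jit, vmap, random

# Report which backend and devices JAX actually selected.
print(jax.default_backend(), jax.devices())

def sum_of_squares(v):
    # Hypothetical test function: f(v) = sum(v_i ** 2), so grad f(v) = 2 * v.
    return np.sum(v ** 2)

key = random.PRNGKey(0)
batch = random.normal(key, (4, 3))

# Gradient per row (vmap), compiled with XLA (jit).
batched_grad = jit(vmap(grad(sum_of_squares)))
result = batched_grad(batch)

# Compare against the analytic gradient on the host with plain NumPy.
assert onp.allclose(onp.asarray(result), 2 * onp.asarray(batch))
```

If the assertion passes, the installed `jax` and `jaxlib` can trace, compile, and differentiate a function on whichever backend was printed.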