Skip to content

Instantly share code, notes, and snippets.

@jvmncs
Created January 25, 2024 19:44
Show Gist options
  • Save jvmncs/02d6364cafb5e95475c927089822d5c4 to your computer and use it in GitHub Desktop.
Save jvmncs/02d6364cafb5e95475c927089822d5c4 to your computer and use it in GitHub Desktop.
Spinning up CUDA-enabled JAX in a Poetry project (January 2024)
#!/bin/bash
# Adds jaxlib and jax to the current Poetry project, each pulled from its own
# explicitly-prioritised package source, then prints the devices jax can see.
# Run inside the Poetry project you want to modify.

# Stop at the first failing command (and on unset variables or pipeline
# failures): every step below depends on the previous one having succeeded.
set -euo pipefail

# Register the jaxlib source. With "explicit" priority the source is only used
# for packages that select it via --source.
poetry source add jaxlib https://storage.googleapis.com/jax-releases/jax_cuda_releases.html --priority explicit

# Add jaxlib with the "cuda12.cudnn89" extra from that source. The requirement
# is quoted so the shell never gets a chance to interpret the "~".
poetry add "jaxlib~=0.4.23" --extras="cuda12.cudnn89" --source=jaxlib

# Register the jax source, again with "explicit" priority.
poetry source add jax https://storage.googleapis.com/jax-releases/jax_releases.html --priority explicit

# Add jax with the "cuda12_pip" extra from that source.
poetry add "jax~=0.4.23" --extras="cuda12_pip" --source=jax

# Smoke test: import jax inside the project environment and list its devices.
poetry run python -c "import jax; print(jax.devices())"
# Example output:
# [cuda(id=0)]
# Package metadata for the project.
[tool.poetry]
name = "foobar"
version = "0.1.0"
description = "fooing the bar, barring the foo"
authors = ["u <u@example.com>"]
readme = "README.md"
packages = [
    { include = "foobar" },
]

# Runtime dependencies. jaxlib and jax each name the [[tool.poetry.source]]
# entry they are resolved from via their `source` key.
[tool.poetry.dependencies]
python = "^3.11,<3.12"
# NOTE(review): "cuda12.cudnn89" resembles a wheel local-version tag; confirm
# that jaxlib really declares an extra with this name.
jaxlib = { version = ">=0.4.23,<0.5.0", extras = ["cuda12.cudnn89"], source = "jaxlib" }
jax = { version = "^0.4.23", extras = ["cuda12_pip"], source = "jax" }

# Additional package sources. Both use priority "explicit", so they are only
# consulted for dependencies whose `source` key names them.
[[tool.poetry.source]]
name = "jax"
url = "https://storage.googleapis.com/jax-releases/jax_releases.html"
priority = "explicit"

[[tool.poetry.source]]
name = "jaxlib"
url = "https://storage.googleapis.com/jax-releases/jax_cuda_releases.html"
priority = "explicit"

# Build configuration: `requires` lists what must be installed to build the
# project, and `build-backend` names the entry point that performs the build.
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment