@cakiki
Forked from skye/tpu_topology_env_vars.py
Created July 4, 2021 11:26
Set these environment variables to run a Python process on a subset of the TPU cores on a Cloud TPU VM. Because only one process can access a given TPU core at a time, restricting each process to its own cores lets you run multiple TPU processes at the same time. Note that in JAX, 1 TPU core = 1 TpuDevice as reported by `jax.devices()`.
import os

# Use one of the two configurations below, not both.
# 4x 1 chip (2 cores) per process:
os.environ["TPU_CHIPS_PER_HOST_BOUNDS"] = "1,1,1"
os.environ["TPU_HOST_BOUNDS"] = "1,1,1"
# Different per process: the first process uses "0", the others "1", "2", "3"
os.environ["TPU_VISIBLE_DEVICES"] = "0"
# Pick a unique port per process
os.environ["TPU_MESH_CONTROLLER_ADDRESS"] = "localhost:8476"
os.environ["TPU_MESH_CONTROLLER_PORT"] = "8476"
# 1-liner for bash: TPU_CHIPS_PER_HOST_BOUNDS=1,1,1 TPU_HOST_BOUNDS=1,1,1 TPU_VISIBLE_DEVICES=0 TPU_MESH_CONTROLLER_ADDRESS=localhost:8476 TPU_MESH_CONTROLLER_PORT=8476
# 2x 2 chips (4 cores) per process:
os.environ["TPU_CHIPS_PER_HOST_BOUNDS"] = "1,2,1"
os.environ["TPU_HOST_BOUNDS"] = "1,1,1"
# Different per process: the first process uses "0,1", the second "2,3"
os.environ["TPU_VISIBLE_DEVICES"] = "0,1"
# Pick a unique port per process
os.environ["TPU_MESH_CONTROLLER_ADDRESS"] = "localhost:8476"
os.environ["TPU_MESH_CONTROLLER_PORT"] = "8476"
# 1-liner for bash: TPU_CHIPS_PER_HOST_BOUNDS=1,2,1 TPU_HOST_BOUNDS=1,1,1 TPU_VISIBLE_DEVICES=0,1 TPU_MESH_CONTROLLER_ADDRESS=localhost:8476 TPU_MESH_CONTROLLER_PORT=8476
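
The two configurations above differ only in the chip bounds, the visible devices, and the port, so the per-process values can be derived from a process index. The following is a minimal sketch of that idea, not part of the original gist: the helper name `tpu_subset_env`, its parameters, the 4-chip host size (inferred from the "4x 1 chip" and "2x 2 chips" cases), and the choice of `base_port + process_index` as the "unique port" are all assumptions made for illustration.

```python
import os

# Chip bounds taken from the two configurations in the gist.
_BOUNDS_BY_CHIPS = {1: "1,1,1", 2: "1,2,1"}

# Assumption: the host exposes 4 chips, numbered 0..3, as in the gist's examples.
_TOTAL_CHIPS = 4


def tpu_subset_env(process_index, chips_per_process=1, base_port=8476):
    """Return the env vars for one process (hypothetical helper, not a JAX API)."""
    if chips_per_process not in _BOUNDS_BY_CHIPS:
        raise ValueError("only 1 or 2 chips per process are covered here")
    num_processes = _TOTAL_CHIPS // chips_per_process
    if not 0 <= process_index < num_processes:
        raise ValueError(f"process_index must be in [0, {num_processes})")

    # Each process gets a contiguous, non-overlapping range of chip ids.
    first_chip = process_index * chips_per_process
    chips = range(first_chip, first_chip + chips_per_process)

    # Assumption: offsetting the port by the process index keeps ports unique.
    port = base_port + process_index

    return {
        "TPU_CHIPS_PER_HOST_BOUNDS": _BOUNDS_BY_CHIPS[chips_per_process],
        "TPU_HOST_BOUNDS": "1,1,1",
        "TPU_VISIBLE_DEVICES": ",".join(str(c) for c in chips),
        "TPU_MESH_CONTROLLER_ADDRESS": f"localhost:{port}",
        "TPU_MESH_CONTROLLER_PORT": str(port),
    }


if __name__ == "__main__":
    # Print a bash-style prefix for each of the four 1-chip processes.
    for i in range(4):
        env = tpu_subset_env(i)
        print(" ".join(f"{key}={value}" for key, value in env.items()))
```

To apply the values inside a process, something like `os.environ.update(tpu_subset_env(0))` should work, presumably as long as it runs before the TPU runtime is initialized (setting the variables before importing JAX is the cautious choice).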