@skye
Last active June 22, 2024 08:09
You can use these environment variables to run a Python process on a subset of the TPU cores on a Cloud TPU VM. This allows running multiple TPU processes at the same time, since only one process can access a given TPU chip at a time. Note that on TPU v2 and v3, 1 TPU chip = 2 TpuDevice as reported by `jax.devices()` (8 devices total). On v4, 1 …
import os

# ==== Non-communicating processes
# 4x 1 chip per process (up to 4 processes on one 2x2 TPU VM, 1 chip each):
os.environ["TPU_CHIPS_PER_PROCESS_BOUNDS"] = "1,1,1"
os.environ["TPU_PROCESS_BOUNDS"] = "1,1,1"
# Different per process:
os.environ["TPU_VISIBLE_DEVICES"] = "0" # "1", "2", "3"
# Pick a unique port per process (same port in both variables)
os.environ["TPU_MESH_CONTROLLER_ADDRESS"] = "localhost:8476"
os.environ["TPU_MESH_CONTROLLER_PORT"] = "8476"
# 1-liner for bash: TPU_CHIPS_PER_PROCESS_BOUNDS=1,1,1 TPU_PROCESS_BOUNDS=1,1,1 TPU_VISIBLE_DEVICES=0 TPU_MESH_CONTROLLER_ADDRESS=localhost:8476 TPU_MESH_CONTROLLER_PORT=8476
# 2x 2 chips per process (up to 2 processes, 2 chips each):
os.environ["TPU_CHIPS_PER_PROCESS_BOUNDS"] = "1,2,1"
os.environ["TPU_PROCESS_BOUNDS"] = "1,1,1"
# Different per process:
os.environ["TPU_VISIBLE_DEVICES"] = "0,1" # "2,3"
# Pick a unique port per process (same port in both variables)
os.environ["TPU_MESH_CONTROLLER_ADDRESS"] = "localhost:8476"
os.environ["TPU_MESH_CONTROLLER_PORT"] = "8476"
# 1-liner for bash: TPU_CHIPS_PER_PROCESS_BOUNDS=1,2,1 TPU_PROCESS_BOUNDS=1,1,1 TPU_VISIBLE_DEVICES=0,1 TPU_MESH_CONTROLLER_ADDRESS=localhost:8476 TPU_MESH_CONTROLLER_PORT=8476
# 1x 4 chips for one process per host (default on v2-8, v3-8, v4-8):
os.environ["TPU_CHIPS_PER_PROCESS_BOUNDS"] = "2,2,1"
os.environ["TPU_PROCESS_BOUNDS"] = "1,1,1"
os.environ["TPU_VISIBLE_DEVICES"] = "0,1,2,3"
# 1-liner for bash: TPU_CHIPS_PER_PROCESS_BOUNDS=2,2,1 TPU_PROCESS_BOUNDS=1,1,1 TPU_VISIBLE_DEVICES=0,1,2,3
# ==== Communicating processes
# 4x 1 chip per process (4 processes, 1 chip each):
os.environ["TPU_CHIPS_PER_PROCESS_BOUNDS"] = "1,1,1"
os.environ["TPU_PROCESS_BOUNDS"] = "2,2,1"
os.environ["TPU_PROCESS_ADDRESSES"] = "localhost:8476,localhost:8477,localhost:8478,localhost:8479"
os.environ["TPU_VISIBLE_DEVICES"] = "0" # "1", "2", "3"
os.environ["TPU_PROCESS_PORT"] = "8476" # "8477", "8478", "8479"
os.environ["CLOUD_TPU_TASK_ID"] = "0" # "1", "2", "3"
# 1-liner for bash: TPU_CHIPS_PER_PROCESS_BOUNDS=1,1,1 TPU_PROCESS_BOUNDS=2,2,1 TPU_PROCESS_ADDRESSES=localhost:8476,localhost:8477,localhost:8478,localhost:8479 TPU_VISIBLE_DEVICES=0 TPU_PROCESS_PORT=8476 CLOUD_TPU_TASK_ID=0
# 2x 2 chips per process (2 processes, 2 chips each):
os.environ["TPU_CHIPS_PER_PROCESS_BOUNDS"] = "1,2,1"
os.environ["TPU_PROCESS_BOUNDS"] = "2,1,1"
os.environ["TPU_PROCESS_ADDRESSES"] = "localhost:8476,localhost:8477"
os.environ["TPU_VISIBLE_DEVICES"] = "0,1" # "2,3"
os.environ["TPU_PROCESS_PORT"] = "8476" # "8477"
os.environ["CLOUD_TPU_TASK_ID"] = "0" # "1"
# 1-liner for bash: TPU_CHIPS_PER_PROCESS_BOUNDS=1,2,1 TPU_PROCESS_BOUNDS=2,1,1 TPU_PROCESS_ADDRESSES=localhost:8476,localhost:8477 TPU_VISIBLE_DEVICES=0,1 TPU_PROCESS_PORT=8476 CLOUD_TPU_TASK_ID=0
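
The per-process values in the communicating setups above differ only in the visible chips, the port, and the task ID, so they can be generated rather than typed out. The following is a minimal sketch, assuming a single 4-chip host as in the settings above; the function name `communicating_env` and its parameters `task_id`, `num_processes`, and `base_port` are hypothetical, and the bounds strings are copied from the two layouts listed above rather than derived.

```python
def communicating_env(task_id, num_processes=4, base_port=8476):
    """Build the environment variables for one communicating process.

    Hypothetical helper: it only covers the two layouts listed above
    (4 processes x 1 chip, 2 processes x 2 chips) on one 4-chip host.
    """
    # (TPU_CHIPS_PER_PROCESS_BOUNDS, TPU_PROCESS_BOUNDS) copied from above.
    layouts = {
        4: ("1,1,1", "2,2,1"),
        2: ("1,2,1", "2,1,1"),
    }
    if num_processes not in layouts:
        raise ValueError("only 2 or 4 processes are covered by this sketch")
    if not 0 <= task_id < num_processes:
        raise ValueError("task_id must be in range(num_processes)")

    chip_bounds, process_bounds = layouts[num_processes]
    chips_per_process = 4 // num_processes
    first_chip = task_id * chips_per_process
    visible = ",".join(
        str(chip) for chip in range(first_chip, first_chip + chips_per_process)
    )
    # Every process gets the same full address list and its own port.
    addresses = ",".join(
        f"localhost:{base_port + i}" for i in range(num_processes)
    )
    return {
        "TPU_CHIPS_PER_PROCESS_BOUNDS": chip_bounds,
        "TPU_PROCESS_BOUNDS": process_bounds,
        "TPU_PROCESS_ADDRESSES": addresses,
        "TPU_VISIBLE_DEVICES": visible,
        "TPU_PROCESS_PORT": str(base_port + task_id),
        "CLOUD_TPU_TASK_ID": str(task_id),
    }
```

Each returned dict can be applied with `os.environ.update(...)` before the process first touches the TPU, or merged into a copy of `os.environ` and passed as `env=` to `subprocess.Popen` when starting the processes from a launcher script.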
@zhongwen

Thanks Skye! I suppose the first line is a typo and it should be

# *1*x 1 chip (2 cores) per process:

@skye
Author

skye commented Sep 28, 2021

What I meant is that you can run up to 4 single-chip processes on a single 2x2 TPU VM, by having 1 process per chip.

@zhongwen

What I meant is that you can run up to 4 single-chip processes on a single 2x2 TPU VM, by having 1 process per chip.

I see. Thanks!
