Last active November 26, 2020 11:23
My Cloud TPU pod setup script for developing pod-related JAX features. I run this on every host in the pod via terminal broadcasting. The final commands, commented out at the bottom, are ones I run manually; maybe there's a better way.
set -eux
PYTHON_VERSION=cp36 # Supported python versions: cp36, cp37, cp38
pip install --upgrade --user$PYTHON_VERSION-none-manylinux2010_x86_64.whl
sudo tee -a /usr/local/lib/python3.6/dist-packages/ > /dev/null <<EOF
import os
import requests
def get_metadata(key):
return requests.get(
'Metadata-Flavor': 'Google'
worker_id = get_metadata('agent-worker-number')
accelerator_type = get_metadata('accelerator-type')
worker_network_endpoints = get_metadata('worker-network-endpoints')
os.environ['CLOUD_TPU_TASK_ID'] = worker_id
os.environ['TPU_CHIPS_PER_HOST_BOUNDS'] = '2,2,1'
accelerator_type_to_host_bounds = {
    'v2-8': '1,1,1',
    'v2-32': '2,2,1',
    'v2-128': '4,4,1',
    'v2-256': '4,8,1',
    'v2-512': '8,8,1',
    'v3-8': '1,1,1',
    'v3-32': '2,2,1',
    'v3-128': '4,4,1',
    'v3-256': '4,8,1',
    'v3-512': '8,8,1',
    'v3-1024': '8,16,1',
    'v3-2048': '16,16,1',
}
os.environ['TPU_HOST_BOUNDS'] = accelerator_type_to_host_bounds[
    accelerator_type]
os.environ['TPU_MESH_CONTROLLER_ADDRESS'] = worker_network_endpoints.split(
',')[0].split(':')[2] + ':8476'
os.environ['TPU_MESH_CONTROLLER_PORT'] = '8476'
EOF
pip install ipython ipdb
export PATH="$HOME/.local/bin:$PATH"
sudo apt install sshfs -y
git config --global user.name "Skye Wanderman-Milne"
git config --global
if [ "$HOST_ID" -eq "0" ]
sudo apt install emacs -y
git clone
cd jax
git remote add skye
# optional: check out your dev branch
mkdir sshfs
ssh-keygen -t rsa -b 4096 -C my_tpu_pod -N '' -f ~/.ssh/id_rsa
cat .ssh/
# emacs .ssh/authorized_keys
# sshfs ...: sshfs # Use _internal_ IP address of host 0!
# cd sshfs
# cd jax
# pip install --upgrade -e .
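
The Python portion of the script above derives a handful of TPU environment variables from instance metadata. Since parts of it are cut off in this copy, here is a minimal, self-contained sketch of the same idea. It is not the script's exact code: it uses the standard library's urllib instead of requests, the metadata URL is an assumption (the usual GCE instance-attributes endpoint, since the original URL is truncated here), and the helper names metadata_request and tpu_env, as well as the sample endpoint string, are hypothetical.

```python
import urllib.request

# Assumption: the usual GCE metadata endpoint for instance attributes.
# The URL used by the script above is truncated in this copy.
METADATA_URL = ('http://metadata.google.internal/computeMetadata'
                '/v1/instance/attributes/{}')

# Same table as in the script above.
ACCELERATOR_TYPE_TO_HOST_BOUNDS = {
    'v2-8': '1,1,1', 'v2-32': '2,2,1', 'v2-128': '4,4,1',
    'v2-256': '4,8,1', 'v2-512': '8,8,1',
    'v3-8': '1,1,1', 'v3-32': '2,2,1', 'v3-128': '4,4,1',
    'v3-256': '4,8,1', 'v3-512': '8,8,1',
    'v3-1024': '8,16,1', 'v3-2048': '16,16,1',
}


def metadata_request(key):
    """Builds (but does not send) a metadata request for one attribute."""
    return urllib.request.Request(
        METADATA_URL.format(key), headers={'Metadata-Flavor': 'Google'})


def get_metadata(key):
    """Fetches one attribute; only works from inside a GCE/TPU VM."""
    with urllib.request.urlopen(metadata_request(key)) as resp:
        return resp.read().decode()


def tpu_env(worker_id, accelerator_type, worker_network_endpoints,
            port='8476'):
    """Computes the environment variables the script sets, as a dict."""
    # As in the script: the controller is the first endpoint in the list,
    # and its IP is the third colon-separated field of that endpoint.
    controller_ip = worker_network_endpoints.split(',')[0].split(':')[2]
    return {
        'CLOUD_TPU_TASK_ID': worker_id,
        'TPU_CHIPS_PER_HOST_BOUNDS': '2,2,1',
        'TPU_HOST_BOUNDS': ACCELERATOR_TYPE_TO_HOST_BOUNDS[accelerator_type],
        'TPU_MESH_CONTROLLER_ADDRESS': controller_ip + ':' + port,
        'TPU_MESH_CONTROLLER_PORT': port,
    }


if __name__ == '__main__':
    # Hypothetical endpoint string, made up to match the split() logic above.
    sample_endpoints = 'a:b:10.0.0.2:8470,c:d:10.0.0.3:8470'
    for name, value in tpu_env('0', 'v3-32', sample_endpoints).items():
        print(name, '=', value)
```

Keeping the computation in a pure function like tpu_env makes it possible to check the mapping off-pod; on a real host the three arguments would come from get_metadata('agent-worker-number'), get_metadata('accelerator-type') and get_metadata('worker-network-endpoints'), and the result would be copied into os.environ.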

Some notes

  • It may be useful to replace the values specific to your setup (e.g., git username, git repo) with variables.
  • This file has the extension ".py", but I think it is actually a bash script. Is that correct?
  • Do you run the commented-out commands at the bottom on all hosts, or only on host 0?
  • You clone the JAX repo and install it from source. Is this necessary, or can we simply pip install JAX if the code we are running lives outside the JAX repo? Or are there some features that haven't been released to PyPI yet?