@skye
skye / gist:597bfbbbf42e8f0fc541007ad3a52fd8
Created August 31, 2022 18:48
Turning on fine-grained PJRT client logging
skyewm@t1v-n-0c604497-w-0:~$ TF_CPP_MIN_LOG_LEVEL=0 TF_CPP_VMODULE=pjrt_stream_executor_client=1 python3 -c "import jax; print(jax.numpy.add(1,1))"
2022-08-31 18:03:35.722123: I external/org_tensorflow/tensorflow/compiler/xla/service/service.cc:173] XLA service 0x2c13080 initialized for platform Interpreter (this does not guarantee that XLA will be used). Devices:
2022-08-31 18:03:35.722165: I external/org_tensorflow/tensorflow/compiler/xla/service/service.cc:181] StreamExecutor device (0): Interpreter, <undefined>
2022-08-31 18:03:35.785537: I external/org_tensorflow/tensorflow/compiler/xla/pjrt/tfrt_cpu_pjrt_client.cc:181] TfrtCpuClient created.
2022-08-31 18:03:35.786313: I external/org_tensorflow/tensorflow/core/tpu/tpu_initializer_helper.cc:262] Libtpu path is: /home/skyewm/.local/lib/python3.8/site-packages/libtpu/libtpu.so
2022-08-31 18:03:41.758624: I external/org_tensorflow/tensorflow/compiler/xla/service/service.cc:173] XLA service 0x536d500 initialized for platform TPU (this does not guarantee th
$ TF_CPP_VMODULE=bfc_allocator=1,stream_executor_pimpl=1 XLA_PYTHON_CLIENT_PREALLOCATE=false TF_CPP_MIN_LOG_LEVEL=0 python3 -c "import jax; jax.numpy.add(1,1)"
2021-08-13 11:12:50.269963: I external/org_tensorflow/tensorflow/compiler/xla/service/service.cc:171] XLA service 0x28efc10 initialized for platform Interpreter (this does not guarantee that XLA will be used). Devices:
2021-08-13 11:12:50.269987: I external/org_tensorflow/tensorflow/compiler/xla/service/service.cc:179] StreamExecutor device (0): Interpreter, <undefined>
2021-08-13 11:12:50.274779: I external/org_tensorflow/tensorflow/compiler/xla/pjrt/tfrt_cpu_pjrt_client.cc:163] TfrtCpuClient created.
ERROR:absl:preallocate false
2021-08-13 11:12:50.384306: I external/org_tensorflow/tensorflow/compiler/xla/service/service.cc:171] XLA service 0x2a6e610 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2021-08-13 11:12:50.384362: I external/org_tensorflow/tensorflow/compiler/xla/service/service.cc:179] StreamEx
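The same logging can also be turned on from inside Python, as a rough sketch: set the environment variables before importing jax so the XLA/PJRT C++ code reads them at startup. The vmodule pattern and log level mirror the commands above; the rest is illustrative and not part of the gist.

import os

# Must be set before `import jax` so the C++ logging flags are picked up.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"  # show INFO-level logs
os.environ["TF_CPP_VMODULE"] = "pjrt_stream_executor_client=1"

import jax
print(jax.numpy.add(1, 1))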
@skye
skye / tpu_topology_env_vars.py
Last active June 22, 2024 08:09
You can use these environment variables to run a Python process on a subset of the TPU cores on a Cloud TPU VM. This allows running multiple TPU processes at the same time, since only one process can access a given TPU chip at a time. Note that on TPU v2 and v3, 1 TPU chip = 2 TpuDevices as reported by `jax.devices()` (8 devices total). On v4, 1 …
import os

# These must be set before importing jax so the TPU runtime picks them up.
# ==== Non-communicating processes
# 4x 1 chip per process:
os.environ["TPU_CHIPS_PER_PROCESS_BOUNDS"] = "1,1,1"
os.environ["TPU_PROCESS_BOUNDS"] = "1,1,1"
# Different per process:
os.environ["TPU_VISIBLE_DEVICES"] = "0" # "1", "2", "3"
# Pick a unique port per process
os.environ["TPU_MESH_CONTROLLER_ADDRESS"] = "localhost:8476"
os.environ["TPU_MESH_CONTROLLER_PORT"] = "8476"
@skye
skye / cloud_tpu_pod_setup.py
Last active November 26, 2020 11:23
My Cloud TPU pod setup script for developing pod-related jax features. I run this on every host in the pod via terminal broadcasting. The final commands commented out at the bottom are ones I run manually; maybe there's a better way.
set -eux
HOST_ID="${HOSTNAME: -1}"
PYTHON_VERSION=cp36 # Supported python versions: cp36, cp37, cp38
pip install --upgrade --user https://storage.googleapis.com/jax-releases/tpu/jaxlib-0.1.55+tpu-$PYTHON_VERSION-none-manylinux2010_x86_64.whl
sudo tee -a /usr/local/lib/python3.6/dist-packages/jax_pod_setup.py > /dev/null <<EOF
import os