Qiao Zhang zhangqiaorjc

zhangqiaorjc / sincos-remat-example.ipynb
Created November 10, 2023 05:27
sincos remat example.ipynb
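The notebook's contents are not shown here. The sketch below is a hedged illustration of the technique its title names: rematerialization ("remat") with `jax.checkpoint` applied to a sin-based function. The function `f` and the choice of policy are illustrative assumptions, not the notebook's code. It checks that checkpointing changes what autodiff stores for the backward pass, not the gradient values.

```python
import jax
import jax.numpy as jnp
import numpy as np


def f(x):
    # sin(sin(x)): the backward pass needs cos(x) and cos(sin(x)).
    return jnp.sin(jnp.sin(x))


# jax.checkpoint (also exported as jax.remat) tells autodiff not to keep f's
# intermediates for the backward pass; they are recomputed when needed.
f_remat = jax.checkpoint(f)

# A policy controls what may still be saved; nothing_saveable recomputes everything.
f_remat_nothing = jax.checkpoint(f, policy=jax.checkpoint_policies.nothing_saveable)

x = jnp.linspace(-2.0, 2.0, 5)
g_plain = jax.grad(lambda v: f(v).sum())(x)
g_remat = jax.grad(lambda v: f_remat(v).sum())(x)
g_nothing = jax.grad(lambda v: f_remat_nothing(v).sum())(x)

# Analytic derivative: d/dx sin(sin(x)) = cos(sin(x)) * cos(x).
xn = np.asarray(x)
expected = np.cos(np.sin(xn)) * np.cos(xn)
print(np.allclose(g_plain, expected, atol=1e-5), np.allclose(g_remat, g_plain, atol=1e-5))
```

Remat trades compute for memory: the recomputation happens inside the backward pass, which matters when the saved intermediates of a large model would not fit in accelerator memory.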
zhangqiaorjc / bfloat16-training.ipynb
Created April 22, 2023 17:04
bfloat16-training.ipynb
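The notebook's contents are not shown here. As a hedged sketch of one common bfloat16 training recipe (an assumption, not necessarily the notebook's), the example below keeps float32 "master" parameters, casts them to bfloat16 for the forward computation, and applies float32 updates. The linear model, `loss_fn`, `sgd_step`, and `LR` are hypothetical.

```python
import jax
import jax.numpy as jnp

LR = 0.1  # hypothetical learning rate


def loss_fn(params, x, y):
    # Forward pass in bfloat16; params themselves stay float32.
    w = params["w"].astype(jnp.bfloat16)
    b = params["b"].astype(jnp.bfloat16)
    pred = x.astype(jnp.bfloat16) @ w + b
    # Compute the loss in float32 to limit rounding error in the reduction.
    err = pred.astype(jnp.float32) - y
    return jnp.mean(err ** 2)


@jax.jit
def sgd_step(params, x, y):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # The casts are differentiated too, so grads come back as float32.
    new_params = jax.tree_util.tree_map(lambda p, g: p - LR * g, params, grads)
    return new_params, loss


key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (64, 3), jnp.float32)
true_w = jnp.array([[1.0], [-2.0], [0.5]], jnp.float32)
y = x @ true_w
params = {"w": jnp.zeros((3, 1), jnp.float32), "b": jnp.zeros((1,), jnp.float32)}

initial_loss = float(loss_fn(params, x, y))
for _ in range(100):
    params, loss = sgd_step(params, x, y)
print(initial_loss, float(loss))
```

Keeping the master copy in float32 avoids losing small updates to bfloat16's 8-bit mantissa, while the matmuls still run in bfloat16.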
zhangqiaorjc / collective_matmul_allgather_lhs_non_contracting.py
Last active September 26, 2023 16:56
## Collective Matmul Example
A: (M, K), B: (K, N), C = A @ B
A's dimension 0 (M, the LHS non-contracting dimension) is sharded over mesh axis 'x'; B and C are replicated.
import numpy as np
import os, re
import jax
from jax.experimental import maps
from jax.experimental import pjit
import jax.numpy as jnp
from jax.experimental import mesh_utils
from absl import flags
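The sharding layout described above can be sketched with the current `jax.jit` / `NamedSharding` API instead of the `maps`/`pjit` imports in the preview. This is an assumption-laden sketch, not the gist's code: the shapes, the forced 4-device CPU mesh, and the input values are hypothetical. It only sets up the layout; whether XLA turns the required all-gather plus matmul into an overlapped collective matmul depends on the backend and flags, which this sketch does not set.

```python
import os

# Assumption: fake 4 CPU devices so the mesh has more than one device.
# Only takes effect if set before JAX initializes; drop it on real GPUs/TPUs.
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=4")

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("x",))
n = devices.size

M, K, N = 8 * n, 16, 32  # M divisible by the number of devices
a = jnp.arange(M * K, dtype=jnp.float32).reshape(M, K) / (M * K)
b = jnp.ones((K, N), jnp.float32)

# A: dim 0 (M, non-contracting) sharded over 'x'. B: replicated.
a = jax.device_put(a, NamedSharding(mesh, P("x", None)))
b = jax.device_put(b, NamedSharding(mesh, P()))

# Request a replicated C; XLA inserts whatever communication is needed.
matmul = jax.jit(lambda lhs, rhs: lhs @ rhs, out_shardings=NamedSharding(mesh, P()))
c = matmul(a, b)
print(c.sharding, c.shape)
```

Because each device holds only M/n rows of A but must end with all of C, some all-gather is unavoidable. The collective-matmul idea is to split that all-gather into n ring steps and overlap each step's transfer with the partial matmul on the shard already received.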
zhangqiaorjc / gpu_scaling.py
Last active January 23, 2023 05:08
Experiment planning with NV
"""Decoder-only LM scaling experiments on GPUs."""
from jax import numpy as jnp
from paxml import experiment_registry
from paxml.tasks.lm.params.lm_cloud import LmCloudSpmd
from paxml.tasks.lm.params.lm_cloud import LmCloudSpmdPipeline
from praxis import layers
# TODO(zhangqiaorjc): Might need to use pmap instead of pjit for smaller models.
zhangqiaorjc / copy-of-jax-transformer-model-for-fp8-no-decode-cache-shared-with-nvidia.ipynb
Created December 14, 2022 07:32
Copy of JAX Transformer model for fp8 (no decode cache) shared with NVIDIA.ipynb
zhangqiaorjc / copy-of-mnist-fp8-for-sharing-with-nvidia.ipynb
Last active December 14, 2022 07:25
Copy of mnist fp8 for sharing with NVIDIA.ipynb
zhangqiaorjc / copy-of-mnist-fp8-for-sharing-with-nvidia.ipynb
Created December 7, 2022 06:26
Copy of mnist fp8 for sharing with NVIDIA.ipynb
"""Runs a simple mnist model with fake FP8. FP8 scaling is used.
The HLO can be dumped by setting the environment variable:
XLA_FLAGS='--xla_dump_disable_metadata=true --xla_dump_to=/tmp/hlo'
"""
import tensorflow as tf
from absl.testing import absltest
from absl import logging
import jax
import jax.numpy as jnp
def amax(x):
  return jnp.max(jnp.abs(x))
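"Fake FP8" with scaling usually means a quantize-dequantize round trip: scale the tensor so its `amax` maps to the largest finite FP8 value, cast to FP8, cast back, and undo the scale. The sketch below is a minimal per-tensor version under stated assumptions: the E4M3 format (`jnp.float8_e4m3fn`, max finite value 448), the helper name `fake_fp8_quantize`, and the scaling recipe are hypothetical, not the gist's code.

```python
import jax
import jax.numpy as jnp

E4M3_MAX = 448.0  # largest finite float8_e4m3fn value


def amax(x):
    return jnp.max(jnp.abs(x))


def fake_fp8_quantize(x, fp8_dtype=jnp.float8_e4m3fn, fp8_max=E4M3_MAX):
    """Hypothetical helper: FP8 quantize-dequantize with a per-tensor amax scale.

    The result has x's dtype but only the values FP8 can represent (after scaling).
    """
    scale = fp8_max / jnp.maximum(amax(x), 1e-12)  # guard against all-zero x
    # Clip so nothing overflows: e4m3fn has no infinity, overflow becomes NaN.
    x_scaled = jnp.clip(x * scale, -fp8_max, fp8_max)
    x_fp8 = x_scaled.astype(fp8_dtype)  # the actual precision loss happens here
    return x_fp8.astype(x.dtype) / scale


x = 3.0 * jax.random.normal(jax.random.PRNGKey(0), (256,), jnp.float32)
xq = jax.jit(fake_fp8_quantize)(x)
rel_err = float(amax(xq - x) / amax(x))
print(xq.dtype, rel_err)
```

With 3 mantissa bits, the worst-case rounding error near the top of the range is half the spacing of 32 at magnitudes 256 to 448, so the error stays under about 3.6% of `amax`.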
zhangqiaorjc / make_hlo.py
Created February 25, 2022 06:36
make_hlo
def make_hlo(f, optimize=False, metadata=False, platform=None):
  """Utility function for printing JAX-emitted HLO and XLA-compiled HLO.
  Args:
    f: jax function to return hlo for.
    optimize: bool: whether to return platform-specific, XLA-optimized HLO
    metadata: bool: whether to include JAX metadata information
    platform: Optional[str]: None, 'cpu','gpu','tpu' - platform to compile for,
      None uses default.
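The same idea, unoptimized HLO versus XLA-optimized HLO, can be sketched with JAX's ahead-of-time lowering API (`jax.jit(f).lower(...)`). This is a hedged sketch, not the gist's implementation: `make_hlo_text` is a hypothetical name, it takes example arguments, it assumes a JAX version where `Lowered.as_text(dialect="hlo")` and `Compiled.as_text()` are available, and it does not implement the `metadata` or `platform` options.

```python
import jax
import jax.numpy as jnp


def make_hlo_text(f, *args, optimize=False):
    """Hypothetical helper returning HLO text for f applied to example args.

    optimize=False: HLO as emitted by JAX, before XLA's optimization passes.
    optimize=True: HLO after XLA compiles it for the default backend.
    """
    lowered = jax.jit(f).lower(*args)
    if optimize:
        return lowered.compile().as_text()
    return lowered.as_text(dialect="hlo")


def sincos(x):
    return jnp.sin(x) * jnp.cos(x)


x = jnp.ones((8,), jnp.float32)
unoptimized = make_hlo_text(sincos, x)
optimized = make_hlo_text(sincos, x, optimize=True)
print(unoptimized)
print(optimized)
```

For a dump of every XLA pass rather than just the two endpoints, the `XLA_FLAGS='--xla_dump_to=...'` setting quoted in the FP8 gist docstring is the heavier alternative.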