Experiment planning with NVIDIA
"""Decoder-only LM scaling experiments on GPUs."""
from jax import numpy as jnp
from paxml import experiment_registry
from paxml.tasks.lm.params.lm_cloud import LmCloudSpmd
from paxml.tasks.lm.params.lm_cloud import LmCloudSpmdPipeline
from praxis import layers
# TODO(zhangqiaorjc): Might need to use pmap instead of pjit for smaller models.
# TODO(zhangqiaorjc): Configure CHECKPOINT_POLICY for all experiments.


@experiment_registry.register
class NvidiaScaling1B(LmCloudSpmd):
  """Model with 1.3B params.

  Global batch size = 4 * 16 * 8 = 512
  This config works on 16 hosts * 8 A100s.
  """
  PERCORE_BATCH_SIZE = 4
  VOCAB_SIZE = 51200
  MAX_SEQ_LEN = 2048
  NUM_HEADS = 32
  DIMS_PER_HEAD = 64
  MODEL_DIMS = 2048
  HIDDEN_DIMS = MODEL_DIMS * 4
  NUM_LAYERS = 24
  FPROP_DTYPE = jnp.float32
  CHECKPOINT_POLICY = layers.AutodiffCheckpointType.SAVE_NOTHING
  # 128-way data parallelism.
  ICI_MESH_SHAPE = [128, 1, 1]
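
# Note (not part of the original gist): judging from the per-class comments, the
# three ICI mesh axes in the non-pipelined configs act as
# [data parallelism, an axis that is always 1 here, tensor parallelism];
# [128, 1, 1] above is pure 128-way data parallelism, while [80, 1, 2] in the
# next config adds 2-way tensor parallelism.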


@experiment_registry.register
class NvidiaScaling5B(LmCloudSpmd):
  """Model with 5B params.

  Global batch size = 8 * 20 * 8 = 1280
  This config works on 20 hosts * 8 A100s.
  """
  PERCORE_BATCH_SIZE = 8
  VOCAB_SIZE = 51200
  MAX_SEQ_LEN = 2048
  NUM_HEADS = 32
  DIMS_PER_HEAD = 128
  MODEL_DIMS = 4096
  HIDDEN_DIMS = MODEL_DIMS * 4
  NUM_LAYERS = 24
  FPROP_DTYPE = jnp.float32
  CHECKPOINT_POLICY = layers.AutodiffCheckpointType.SAVE_NOTHING
  # 80-way data parallelism, 2-way tensor parallelism.
  ICI_MESH_SHAPE = [80, 1, 2]


@experiment_registry.register
class NvidiaScaling8B(LmCloudSpmd):
  """Model with 8.3B params.

  Global batch size = 4 * 16 * 8 = 512
  This config works on 16 hosts * 8 A100s.
  """
  PERCORE_BATCH_SIZE = 4
  VOCAB_SIZE = 51200
  MAX_SEQ_LEN = 2048
  NUM_HEADS = 64
  DIMS_PER_HEAD = 64
  MODEL_DIMS = 4096
  HIDDEN_DIMS = MODEL_DIMS * 4
  NUM_LAYERS = 40
  FPROP_DTYPE = jnp.float32
  CHECKPOINT_POLICY = layers.AutodiffCheckpointType.SAVE_NOTHING
  # 32-way data parallelism, 4-way tensor parallelism.
  ICI_MESH_SHAPE = [32, 1, 4]


@experiment_registry.register
class NvidiaScaling10B(LmCloudSpmd):
  """Model with 10B params.

  Global batch size = 2.25 * 80 * 8 = 1440
  This config works on 80 hosts * 8 A100s.
  """
  PERCORE_BATCH_SIZE = 2.25
  VOCAB_SIZE = 51200
  MAX_SEQ_LEN = 2048
  NUM_HEADS = 40
  DIMS_PER_HEAD = 128
  MODEL_DIMS = 5120
  HIDDEN_DIMS = MODEL_DIMS * 4
  NUM_LAYERS = 32
  FPROP_DTYPE = jnp.float32
  CHECKPOINT_POLICY = layers.AutodiffCheckpointType.SAVE_NOTHING
  # 80-way data parallelism, 8-way tensor parallelism.
  ICI_MESH_SHAPE = [80, 1, 8]


@experiment_registry.register
class NvidiaScaling20B(LmCloudSpmd):
  """Model with 20B params.

  Global batch size = 2.25 * 80 * 8 = 1440
  This config works on 80 hosts * 8 A100s.
  """
  PERCORE_BATCH_SIZE = 2.25
  VOCAB_SIZE = 51200
  MAX_SEQ_LEN = 2048
  NUM_HEADS = 48
  DIMS_PER_HEAD = 128
  MODEL_DIMS = 6144
  HIDDEN_DIMS = MODEL_DIMS * 4
  NUM_LAYERS = 44
  FPROP_DTYPE = jnp.float32
  CHECKPOINT_POLICY = layers.AutodiffCheckpointType.SAVE_NOTHING
  # 80-way data parallelism, 8-way tensor parallelism.
  ICI_MESH_SHAPE = [80, 1, 8]


@experiment_registry.register
class NvidiaScaling40B(LmCloudSpmdPipeline):
  """Model with 40B params.

  Global batch size = 2.25 * 80 * 8 = 1440
  This config works on 80 hosts * 8 A100s.
  """
  MICROBATCH_SIZE = 2
  PERCORE_BATCH_SIZE = 2.25
  NUM_STAGES = 4
  VOCAB_SIZE = 51200
  MAX_SEQ_LEN = 2048
  NUM_HEADS = 48
  DIMS_PER_HEAD = 128
  MODEL_DIMS = 6144
  HIDDEN_DIMS = MODEL_DIMS * 4
  NUM_LAYERS = 44
  FPROP_DTYPE = jnp.float32
  CHECKPOINT_POLICY = layers.AutodiffCheckpointType.SAVE_NOTHING
  # 20-way data, 4-way pipeline and 8-way model parallelism.
  ICI_MESH_SHAPE = [1, 20, 1, 8]
  DCN_MESH_SHAPE = [4, 1, 1, 1]
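
# Note (not part of the original gist): for the pipelined configs, the figures in
# the comments appear to follow from combining ICI_MESH_SHAPE and DCN_MESH_SHAPE
# axis by axis. For the 40B config above, [1, 20, 1, 8] combined with [4, 1, 1, 1]
# yields a [4, 20, 1, 8] logical mesh: 4 pipeline stages spread over the
# data-center network, with 20-way data and 8-way tensor parallelism inside each
# stage, for a total of 4 * 20 * 8 = 640 GPUs = 80 hosts * 8 A100s, which matches
# the quoted global batch size of 2.25 * 640 = 1440.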


@experiment_registry.register
class NvidiaScaling116B(LmCloudSpmdPipeline):
  """Model with 116B params.

  Global batch size = 1.5 * 128 * 8 = 1536
  This config works on 128 hosts * 8 A100s.
  """
  MICROBATCH_SIZE = 2
  PERCORE_BATCH_SIZE = 1.5
  NUM_STAGES = 8
  VOCAB_SIZE = 51200
  MAX_SEQ_LEN = 2048
  NUM_HEADS = 96
  DIMS_PER_HEAD = 128
  MODEL_DIMS = 12288
  HIDDEN_DIMS = MODEL_DIMS * 4
  NUM_LAYERS = 64
  FPROP_DTYPE = jnp.float32
  CHECKPOINT_POLICY = layers.AutodiffCheckpointType.SAVE_NOTHING
  # 16-way data, 8-way pipeline and 8-way model parallelism.
  ICI_MESH_SHAPE = [1, 16, 1, 8]
  DCN_MESH_SHAPE = [8, 1, 1, 1]


@experiment_registry.register
class NvidiaScaling175B(LmCloudSpmdPipeline):
  """Model with 175B params.

  Global batch size = 1.5 * 128 * 8 = 1536
  This config works on 128 hosts * 8 A100s.
  """
  MICROBATCH_SIZE = 1
  PERCORE_BATCH_SIZE = 1.5
  NUM_STAGES = 8
  VOCAB_SIZE = 51200
  MAX_SEQ_LEN = 2048
  NUM_HEADS = 96
  DIMS_PER_HEAD = 128
  MODEL_DIMS = 12288
  HIDDEN_DIMS = MODEL_DIMS * 4
  NUM_LAYERS = 96
  FPROP_DTYPE = jnp.float32
  CHECKPOINT_POLICY = layers.AutodiffCheckpointType.SAVE_NOTHING
  # 16-way data, 8-way pipeline and 8-way model parallelism.
  ICI_MESH_SHAPE = [1, 16, 1, 8]
  DCN_MESH_SHAPE = [8, 1, 1, 1]


@experiment_registry.register
class GoogleScaling175B(LmCloudSpmdPipeline):
  """Model with 175B params.

  Global batch size = 1.5 * 128 * 8 = 1536
  This config works on 128 hosts * 8 A100s.
  """
  MICROBATCH_SIZE = 1
  PERCORE_BATCH_SIZE = 1.5
  NUM_STAGES = 8
  VOCAB_SIZE = 51200
  MAX_SEQ_LEN = 2048
  NUM_HEADS = 96
  DIMS_PER_HEAD = 128
  MODEL_DIMS = 12288
  HIDDEN_DIMS = MODEL_DIMS * 4
  NUM_LAYERS = 96
  FPROP_DTYPE = jnp.float32
  CHECKPOINT_POLICY = layers.AutodiffCheckpointType.SAVE_NOTHING
  # 16-way data, 8-way pipeline and 8-way model parallelism.
  ICI_MESH_SHAPE = [1, 16, 1, 8]
  DCN_MESH_SHAPE = [8, 1, 1, 1]
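

# Sanity-check sketch (not part of the original gist). The docstrings above all
# quote global batch size = PERCORE_BATCH_SIZE * num_hosts * 8; assuming the
# total GPU count is the product of the ICI and DCN mesh dimensions, the quoted
# numbers can be recomputed directly from each config.
import math


def _total_gpus(cfg):
  """Total device count implied by a config's mesh shapes."""
  ici = math.prod(cfg.ICI_MESH_SHAPE)
  dcn = math.prod(getattr(cfg, 'DCN_MESH_SHAPE', None) or [1])
  return ici * dcn


if __name__ == '__main__':
  # Expected global batch sizes, as quoted in the class docstrings above.
  expected = {
      NvidiaScaling1B: 512,
      NvidiaScaling5B: 1280,
      NvidiaScaling8B: 512,
      NvidiaScaling10B: 1440,
      NvidiaScaling20B: 1440,
      NvidiaScaling40B: 1440,
      NvidiaScaling116B: 1536,
      NvidiaScaling175B: 1536,
      GoogleScaling175B: 1536,
  }
  for cfg, batch in expected.items():
    assert cfg.PERCORE_BATCH_SIZE * _total_gpus(cfg) == batch, cfg.__name__

# If these behave like the stock lm_cloud experiments, each registered config can
# presumably be launched through the Pax entry point, e.g.
#   python -m paxml.main --exp=<path.to.this.module>.NvidiaScaling1B --job_log_dir=<log dir>
# with the module path depending on where this file is checked in.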