@float-trip · Last active July 16, 2023

Finetuning guide for MPT

This assumes you're starting with a jsonl file containing your data in the format: { "text": "..." }. If your dataset follows an instruct/chat format, see alternate data prep here and here, as well as the train_loader/eval_loader here.
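
For reference, a minimal sketch of producing such a file with Python (the docs list and output path here are placeholders, not part of the original guide):

import json

# Hypothetical raw documents - substitute your own training data.
docs = ["First training document...", "Second training document..."]

# Write one JSON object per line in the { "text": "..." } format described above.
with open("samples.jsonl", "w") as f:
    for doc in docs:
        f.write(json.dumps({"text": doc}) + "\n")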

Example 30b config

This config gets ~641 tokens/sec/GPU. For other sequence lengths, see these benchmarks as a reference. Although I needed to pick smaller batch sizes (likely due to using adamw instead of lion), I generally matched the throughput from that table. As a rough back-of-envelope estimate, 641 tokens/sec/GPU across 8 GPUs is ~5.1k tokens/sec, so a hypothetical 100M-token dataset trained for 4 epochs (~400M tokens) would take on the order of 20-22 hours.

Note that you will not be able to resume runs using load_path from the checkpoints saved by this config. A possible workaround may be converting a checkpoint to the HF format (with scripts/inference/convert_composer_to_hf.py, shown at the end of this guide) and starting a fresh run from that via pretrained_model_name_or_path. See this issue for more info.

data_local: /splits
data_remote:
max_seq_len: 8192
global_seed: 17

# Run Name
run_name: mpt-30b-ft

# Model
model:
  name: hf_causal_lm
  pretrained: true
  pretrained_model_name_or_path: mosaicml/mpt-30b
  config_overrides:
    max_seq_len: ${max_seq_len}
    attn_config:
      attn_impl: triton
      attn_uses_sequence_id: false

# Tokenizer
tokenizer:
  name: mosaicml/mpt-30b
  kwargs:
    model_max_length: ${max_seq_len}

# Dataloaders
train_loader:
  name: text
  dataset:
    local: ${data_local}
    remote: ${data_remote}
    split: train
    shuffle: true
    max_seq_len: ${max_seq_len}
    shuffle_seed: ${global_seed}
  drop_last: true
  num_workers: 8

eval_loader:
  name: text
  dataset:
    local: ${data_local}
    remote: ${data_remote}
    split: validation
    shuffle: false
    max_seq_len: ${max_seq_len}
    shuffle_seed: ${global_seed}
  drop_last: false
  num_workers: 8

# Optimization
scheduler:
  name: cosine_with_warmup
  # Change depending on the number of tokens you have.
  # No idea what's best - the finetuning guide for mesh-transformer-jax suggests 5%-10% of your total batch count.
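  # Hypothetical example: with ~100M training tokens, max_seq_len 8192, and a
  # global batch size of 64, one batch is ~524k tokens, so an epoch is ~190
  # batches and 4 epochs ~760 batches; 5-10% of that is ~40-75 warmup batches.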
  t_warmup: 40ba
  alpha_f: 0.1

optimizer:
  name: decoupled_adamw
  lr: 7.0e-5
  betas:
  - 0.9
  - 0.99
  eps: 1.0e-08
  weight_decay: 0.0

algorithms:
  gradient_clipping:
    clipping_type: norm
    clipping_threshold: 1.0

max_duration: 4ep
eval_interval: 1ep
eval_first: false
eval_subset_num_batches: -1
global_train_batch_size: 64

# System
seed: ${global_seed}
device_eval_batch_size: 1
device_train_microbatch_size: 1
precision: amp_bf16

# FSDP
fsdp_config:
  sharding_strategy: FULL_SHARD
  mixed_precision: PURE
  activation_checkpointing: true
  activation_checkpointing_reentrant: false
  activation_cpu_offload: false
  limit_all_gathers: true
  verbose: false
  state_dict_type: sharded

# Logging
progress_bar: false
log_to_console: true
console_log_interval: 1ba

callbacks:
  speed_monitor:
    window_size: 10
  lr_monitor: {}
  memory_monitor: {}
  runtime_estimator: {}
  mono_ckpt_saver:
    save_folder: /checkpoints/{run_name}
    batch_interval: 100 # Save frequency. Change this!

loggers:
  wandb:
    project: mpt-30b

Example 7b config

data_local: /splits
data_remote:
max_seq_len: 8192
global_seed: 17

# Run Name
run_name: mpt-7b-ft

# Model
model:
  name: hf_causal_lm
  pretrained: true
  pretrained_model_name_or_path: mosaicml/mpt-7b
  config_overrides:
    max_seq_len: ${max_seq_len}
    attn_config:
      attn_impl: triton
      attn_uses_sequence_id: false

# Tokenizer
tokenizer:
  name: mosaicml/mpt-7b
  kwargs:
    model_max_length: ${max_seq_len}

# Dataloaders
train_loader:
  name: text
  dataset:
    local: ${data_local}
    remote: ${data_remote}
    split: train
    shuffle: true
    max_seq_len: ${max_seq_len}
    shuffle_seed: ${global_seed}
  drop_last: true
  num_workers: 8

eval_loader:
  name: text
  dataset:
    local: ${data_local}
    remote: ${data_remote}
    split: validation
    shuffle: false
    max_seq_len: ${max_seq_len}
    shuffle_seed: ${global_seed}
  drop_last: false
  num_workers: 8

# Optimization
scheduler:
  name: cosine_with_warmup
  # Change depending on the number of tokens you have.
  # No idea what's best - the finetuning guide for mesh-transformer-jax suggests 5%-10% of your total batch count.
  t_warmup: 40ba
  alpha_f: 0.1

optimizer:
  name: decoupled_adamw
  lr: 7.0e-5
  betas:
  - 0.9
  - 0.99
  eps: 1.0e-08
  weight_decay: 0.0

algorithms:
  gradient_clipping:
    clipping_type: norm
    clipping_threshold: 1.0

max_duration: 4ep
eval_interval: 1ep
eval_first: false
eval_subset_num_batches: -1
global_train_batch_size: 64

# System
seed: ${global_seed}
device_eval_batch_size: 4
device_train_microbatch_size: 4
precision: amp_bf16

# FSDP
fsdp_config:
  sharding_strategy: FULL_SHARD
  mixed_precision: PURE
  activation_checkpointing: true
  activation_checkpointing_reentrant: false
  activation_cpu_offload: false
  limit_all_gathers: true
  verbose: false

# Logging
progress_bar: false
log_to_console: true
console_log_interval: 1ba

callbacks:
  speed_monitor:
    window_size: 10
  lr_monitor: {}
  memory_monitor: {}
  runtime_estimator: {}

loggers:
  wandb:
    project: mpt-7b
  
save_interval: 100ba
save_num_checkpoints_to_keep: 25
save_folder: /checkpoints/{run_name}

Wait around for an 8xA100 80GB instance from Lambda Labs to become available:

# `apt/brew install jq`
# `brew install terminal-notifier`, or replace with platform equivalent
api_key=""
ssh_key=""
target_instance="gpu_8x_a100_80gb_sxm4"

while true
do
  capacity=$(curl -s \
    -H "Authorization: Basic ${api_key}" \
    "https://cloud.lambdalabs.com/api/v1/instance-types" | \
    jq ".data.${target_instance}.regions_with_capacity_available")

  if [[ "${capacity}" != "[]" ]] && [[ "${capacity}" != "null" ]]; then
    echo "Trying to launch instance..."
    region=$(jq -r ".[0].name" <<< "${capacity}")  # pick the first region with capacity
    
    launch_results=$(curl -X POST -H 'Content-Type: application/json' -H "Authorization: Basic ${api_key}" \
      -d '{
        "region_name": "'${region}'",
        "instance_type_name": "'${target_instance}'",
        "ssh_key_names": ["'${ssh_key}'"]
      }' \
      "https://cloud.lambdalabs.com/api/v1/instance-operations/launch")

    instance_ids=$(jq -r ".data.instance_ids" <<< "${launch_results}")

    if [[ "${instance_ids}" != "null" ]]; then
      terminal-notifier -message "Instance is up!"
      exit
    fi
  fi

  sleep 10
done

SSH into the instance:

# Run Mosaic's Docker image.
sudo docker run --gpus all -dit --shm-size=5gb --name mosaic mosaicml/llm-foundry:2.0.1_cu118-latest

# Open a shell on Docker:
sudo docker exec -it mosaic /bin/bash
sudo apt update && sudo apt -y install magic-wormhole parallel

# Transfer "samples.jsonl" and "config.yaml".
# elsewhere: wormhole send ...
wormhole receive ...

# Then install dependencies:
git clone https://github.com/mosaicml/llm-foundry
cd llm-foundry
pip install -e ".[gpu]"

# See https://github.com/mosaicml/llm-foundry/issues/367
sed -i '518s/.*/       if True:/' /usr/lib/python3/dist-packages/torch/distributed/fsdp/_state_dict_utils.py

# Create train/validation sets.
total_lines=$(wc -l < /samples.jsonl)
shuf /samples.jsonl > /shuffled.jsonl
validation_lines=$((total_lines / 5))  # hold out ~20% for validation
training_lines=$((total_lines - validation_lines))
head -n ${training_lines} /shuffled.jsonl > /train.jsonl
tail -n +$((training_lines + 1)) /shuffled.jsonl > /validation.jsonl

mkdir /splits
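
# Convert each split from JSONL to the StreamingDataset (MDS) format that the
# text dataloaders above expect, concatenating documents into 8192-token samples.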
convert_dataset() {
    python scripts/data_prep/convert_dataset_json.py \
        --path /$1.jsonl \
        --out_root /splits/$1 \
        --split train \
        --concat_tokens 8192 \
        --tokenizer "EleutherAI/gpt-neox-20b" \
        --eos_text '<|endoftext|>'
}

export -f convert_dataset
parallel convert_dataset ::: train validation

# Train!
composer scripts/train/train.py /config.yaml

# Upload to HuggingFace.
export HUGGING_FACE_HUB_TOKEN=your-auth-token
python scripts/inference/convert_composer_to_hf.py \
  --composer_path /checkpoints/[checkpoint].pt \
  --hf_output_path mpt-30b-ft-hf \
  --output_precision bf16 \
  --hf_repo_for_upload user-org/repo-name
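
After conversion, a quick way to sanity check the exported model (a sketch, assuming the local mpt-30b-ft-hf folder produced above and that transformers and accelerate are available in the container):

import torch
import transformers

# MPT ships custom modeling code, so trust_remote_code is required.
tokenizer = transformers.AutoTokenizer.from_pretrained("mpt-30b-ft-hf")
model = transformers.AutoModelForCausalLM.from_pretrained(
    "mpt-30b-ft-hf",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    device_map="auto",  # needs accelerate; spreads the 30b weights across available GPUs
)

prompt = "Once upon a time"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
output = model.generate(**inputs, max_new_tokens=50)
print(tokenizer.decode(output[0], skip_special_tokens=True))
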
@alanxmay

Thanks for your work. I tried the mpt-30b config on 8x A100 (80GB) and hit CUDA OOM when saving checkpoints. Do you have any suggestions?

@float-trip (Author)

Yeah, see this issue: mosaicml/llm-foundry#367

Look for the "sed" line above to fix it.

@alanxmay

Thanks, will test it right now!
