@Birch-san
Last active October 3, 2023 18:20
Unexplained memory usage of 8-bit AdamW (paged vs unpaged)

Some weird memory usage (VRAM) is reported (by torch and by NVML) when using 8-bit AdamW, paged or unpaged.

Here we train Llama 2 on 4096-token sequences, using either --optim adamw_8bit or --optim paged_adamw_8bit.
We do a full finetune using qlora.py --full_finetune, with our qlora.py fork, stepwise branch, commit 9a1045d.
We print the memory usage using HF transformers trainer's on_step_end callback. This is after optimizer.step(); model.zero_grad().
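
The MemoryUsageCallback itself is not reproduced in this note, so the following is only a minimal sketch of how such a callback could be written, not the code from the fork. The helper names read_torch_memory, format_report and make_memory_callback are hypothetical. It reports the raw values from torch.cuda.memory_allocated and torch.cuda.memory_reserved, which may not be computed the same way as the "Reserved" column in the logs below.

```python
from typing import Dict, List

MIB = 1024 ** 2


def read_torch_memory(synchronize: bool = False) -> List[Dict[str, int]]:
    """Read per-device byte counts from PyTorch's caching allocator."""
    import torch  # imported lazily so format_report works on a machine without torch

    stats = []
    for device in range(torch.cuda.device_count()):
        if synchronize:
            # Wait for queued kernels on this device before reading the counters.
            torch.cuda.synchronize(device)
        stats.append({
            "allocated": torch.cuda.memory_allocated(device),
            "reserved": torch.cuda.memory_reserved(device),
        })
    return stats


def format_report(step: int, stats: List[Dict[str, int]]) -> str:
    """Render one step's readings as text, in whole MiB."""
    lines = [f"step {step}", "  Torch memory stats (allocated, reserved):"]
    for index, s in enumerate(stats):
        lines.append(
            f"    Device {index}: Allocated: {s['allocated'] // MIB}MiB, "
            f"Reserved: {s['reserved'] // MIB}MiB"
        )
    total_allocated = sum(s["allocated"] for s in stats)
    total_reserved = sum(s["reserved"] for s in stats)
    lines.append(
        f"    Overall: Allocated: {total_allocated // MIB}MiB, "
        f"Reserved: {total_reserved // MIB}MiB"
    )
    return "\n".join(lines)


def make_memory_callback(synchronize: bool = False):
    """Build a HF Trainer callback that prints memory at the end of each optimizer step."""
    from transformers import TrainerCallback  # lazy import, same reason as above

    class MemoryUsageCallback(TrainerCallback):
        def on_step_end(self, args, state, control, **kwargs):
            print(format_report(state.global_step, read_torch_memory(synchronize)))

    return MemoryUsageCallback()
```

With a transformers Trainer instance, this could be registered via trainer.add_callback(make_memory_callback(synchronize=True)).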

One would expect the memory usage at the end of step 1 to be the same as the end of step 2.
Yet for the unpaged optimizer, memory usage leaps by about 11.3GiB (11546MiB). End of step 1=70.4GiB, end of step 2=81.6GiB.
This appears to be a leap in PyTorch reserved memory only (32.6GiB -> 43.9GiB).

One would expect the memory usage of the paged optimizer to be equal to or lower than that of the unpaged optimizer.
Yet the unpaged optimizer uses 70.4GiB -> 81.6GiB (steps 1 and 2, as reported by Torch; 72.3GiB -> 83.6GiB as reported by NVML), whereas the paged optimizer uses 91.2GiB (as reported by NVML).

By the end of step 2: both optimizers more or less agree on the amount of Torch reserved memory (unpaged=43.9GiB, paged=45.2GiB).
But throughout: the unpaged optimizer uses far more Torch allocated memory (37.7GiB) than the paged optimizer (25.2GiB).
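
As a sanity check on the GiB figures quoted above, here is a small sketch that converts the "Overall" readings from the logs further down (reported in MiB) to GiB. The dictionary layout and the mib_to_gib helper are only for illustration; the numbers are copied from the logs in this note.

```python
MIB_PER_GIB = 1024


def mib_to_gib(mib: int) -> float:
    """Convert a MiB reading to GiB."""
    return mib / MIB_PER_GIB


# "Overall" Torch readings (MiB), steps 1 and 2, copied from the logs below.
torch_overall = {
    "adamw_8bit": {
        1: {"used": 72040, "allocated": 38647, "reserved": 33392},
        2: {"used": 83586, "allocated": 38647, "reserved": 44938},
    },
    "paged_adamw_8bit": {
        1: {"used": 72040, "allocated": 25795, "reserved": 46244},
        2: {"used": 72044, "allocated": 25795, "reserved": 46248},
    },
}

# "Overall" NVML readings (MiB), steps 1 and 2, copied from the logs below.
nvml_overall = {
    "adamw_8bit": {1: 74061, 2: 85607},
    "paged_adamw_8bit": {1: 93413, 2: 93415},
}

unpaged = torch_overall["adamw_8bit"]
jump_mib = unpaged[2]["used"] - unpaged[1]["used"]
reserved_jump_mib = unpaged[2]["reserved"] - unpaged[1]["reserved"]

print(f"unpaged Torch used: {mib_to_gib(unpaged[1]['used']):.1f} -> "
      f"{mib_to_gib(unpaged[2]['used']):.1f} GiB (jump {mib_to_gib(jump_mib):.1f} GiB)")
print(f"unpaged Torch reserved: {mib_to_gib(unpaged[1]['reserved']):.1f} -> "
      f"{mib_to_gib(unpaged[2]['reserved']):.1f} GiB")
for optim, readings in nvml_overall.items():
    print(f"{optim} NVML used at step 2: {mib_to_gib(readings[2]):.1f} GiB")
```

The jump in "used" equals the jump in "reserved" (11546MiB in both cases), which is consistent with the observation that only reserved memory grows between steps 1 and 2.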

The jump between steps 1 and 2 is probably not related to gradient accumulation. We are using gradient accumulation, but gradients are accumulated during microsteps; on_step_end is called after the microsteps.

I have tried modifying my MemoryUsageCallback to use torch.cuda.synchronize(). This does not alter the result.

I ran the program like so (device_map='auto' over 2xA40 on Linux, CUDA 12.1, compute capability 8.6):

ACCELERATE_MIXED_PRECISION=bf16 python -m qlora \
--device_map_auto \
--disable_tqdm True \
--use_auth_token True \
--model_name_or_path meta-llama/Llama-2-7b-chat-hf \
--use_flash_llama \
--trust_remote_code \
--dataset prm800k-solutions \
--dataset_format prm800k-solutions \
--max_memory_MB 48000 --simulate_worst_case_seq_len \
--truncate_toward_center \
--source_max_len 2048 \
--target_max_len 2048 \
--gradient_accumulation_steps 2 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--optim adamw_8bit \
--learning_rate 0.0002 \
--save_steps 196 \
--save_total_limit 1 \
--max_steps 256 \
--evaluation_strategy steps \
--eval_steps 196 \
--measure_memory \
--terminate_after_step 8 \
--bits 32 \
--full_finetune

The memory readings I get (from Torch and from NVML) are as follows.

adamw_8bit:

step 1
  NVML memory stats (used+reserved, all processes):
    Device 0: Used 36314MiB / 49140MiB
    Device 1: Used 37746MiB / 49140MiB
    Overall: Used 74061MiB / 98280MiB
  Torch memory stats (allocated, reserved):
    Device 0: Used 35304MiB (Allocated: 19323MiB, Reserved 15980MiB)
    Device 1: Used 36736MiB (Allocated: 19323MiB, Reserved 17412MiB)
    Overall: Used 72040MiB (Allocated: 38647MiB, Reserved 33392MiB)
step 2
  NVML memory stats (used+reserved, all processes):
    Device 0: Used 42086MiB / 49140MiB
    Device 1: Used 43520MiB / 49140MiB
    Overall: Used 85607MiB / 98280MiB
  Torch memory stats (allocated, reserved):
    Device 0: Used 41076MiB (Allocated: 19323MiB, Reserved 21752MiB)
    Device 1: Used 42510MiB (Allocated: 19323MiB, Reserved 23186MiB)
    Overall: Used 83586MiB (Allocated: 38647MiB, Reserved 44938MiB)
step 3
  NVML memory stats (used+reserved, all processes):
    Device 0: Used 42086MiB / 49140MiB
    Device 1: Used 43520MiB / 49140MiB
    Overall: Used 85607MiB / 98280MiB
  Torch memory stats (allocated, reserved):
    Device 0: Used 41076MiB (Allocated: 19323MiB, Reserved 21752MiB)
    Device 1: Used 42510MiB (Allocated: 19323MiB, Reserved 23186MiB)
    Overall: Used 83586MiB (Allocated: 38647MiB, Reserved 44938MiB)

paged_adamw_8bit:

step 1
  NVML memory stats (used+reserved, all processes):
    Device 0: Used 49138MiB / 49140MiB
    Device 1: Used 44274MiB / 49140MiB
    Overall: Used 93413MiB / 98280MiB
  Torch memory stats (allocated, reserved):
    Device 0: Used 35304MiB (Allocated: 12897MiB, Reserved 22406MiB)
    Device 1: Used 36736MiB (Allocated: 12897MiB, Reserved 23838MiB)
    Overall: Used 72040MiB (Allocated: 25795MiB, Reserved 46244MiB)
step 2
  NVML memory stats (used+reserved, all processes):
    Device 0: Used 49138MiB / 49140MiB
    Device 1: Used 44276MiB / 49140MiB
    Overall: Used 93415MiB / 98280MiB
  Torch memory stats (allocated, reserved):
    Device 0: Used 35306MiB (Allocated: 12897MiB, Reserved 22408MiB)
    Device 1: Used 36738MiB (Allocated: 12897MiB, Reserved 23840MiB)
    Overall: Used 72044MiB (Allocated: 25795MiB, Reserved 46248MiB)
step 3
  NVML memory stats (used+reserved, all processes):
    Device 0: Used 49138MiB / 49140MiB
    Device 1: Used 44276MiB / 49140MiB
    Overall: Used 93415MiB / 98280MiB
  Torch memory stats (allocated, reserved):
    Device 0: Used 35306MiB (Allocated: 12897MiB, Reserved 22408MiB)
    Device 1: Used 36738MiB (Allocated: 12897MiB, Reserved 23840MiB)
    Overall: Used 72044MiB (Allocated: 25795MiB, Reserved 46248MiB)
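
To compare steps without reading the numbers by eye, the logs above could be parsed with something like the following sketch. It assumes the exact line format shown above; parse_torch_overall and step_deltas are hypothetical helper names, and the sample text is an excerpt of the adamw_8bit log.

```python
import re
from typing import Dict, Tuple

STEP_RE = re.compile(r"^step (\d+)$")
# Matches only the Torch "Overall" line; the NVML "Overall" line has no parenthesised part.
TORCH_OVERALL_RE = re.compile(
    r"Overall: Used (\d+)MiB \(Allocated: (\d+)MiB, Reserved (\d+)MiB\)"
)


def parse_torch_overall(log: str) -> Dict[int, Dict[str, int]]:
    """Map step number -> overall Torch readings in MiB."""
    readings: Dict[int, Dict[str, int]] = {}
    step = None
    for raw in log.splitlines():
        line = raw.strip()
        step_match = STEP_RE.match(line)
        if step_match:
            step = int(step_match.group(1))
            continue
        overall_match = TORCH_OVERALL_RE.search(line)
        if overall_match and step is not None:
            used, allocated, reserved = map(int, overall_match.groups())
            readings[step] = {"used": used, "allocated": allocated, "reserved": reserved}
    return readings


def step_deltas(readings: Dict[int, Dict[str, int]]) -> Dict[Tuple[int, int], Dict[str, int]]:
    """Difference in each reading between consecutive steps, in MiB."""
    steps = sorted(readings)
    return {
        (a, b): {key: readings[b][key] - readings[a][key] for key in readings[a]}
        for a, b in zip(steps, steps[1:])
    }


SAMPLE_LOG = """\
step 1
  NVML memory stats (used+reserved, all processes):
    Overall: Used 74061MiB / 98280MiB
  Torch memory stats (allocated, reserved):
    Overall: Used 72040MiB (Allocated: 38647MiB, Reserved 33392MiB)
step 2
  NVML memory stats (used+reserved, all processes):
    Overall: Used 85607MiB / 98280MiB
  Torch memory stats (allocated, reserved):
    Overall: Used 83586MiB (Allocated: 38647MiB, Reserved 44938MiB)
"""

deltas = step_deltas(parse_torch_overall(SAMPLE_LOG))
print(deltas)
```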

Versions:

  • bitsandbytes 0.41.1
  • transformers 4.32.1
  • accelerate 0.22.0
  • torch 2.1.0.dev20230802+cu121