Unexpected VRAM usage is reported (by both Torch and NVML) when using 8-bit AdamW, whether paged or unpaged.
Here we train Llama 2 on 4096-token sequences, using either `--optim adamw_8bit` or `--optim paged_adamw_8bit`.
We do a full finetune via `qlora.py --full_finetune`, with our qlora.py fork, stepwise branch, commit 9a1045d.
We print the memory usage from the HF transformers trainer's `on_step_end` callback, i.e. after `optimizer.step(); model.zero_grad()`.
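For context, here is a minimal sketch of what such a callback might look like (hypothetical: the fork's actual `MemoryUsageCallback` differs, and the `mib_to_gib` helper is my own):

```python
def mib_to_gib(mib: float) -> str:
    """Format a MiB count as a one-decimal GiB string (1 GiB = 1024 MiB)."""
    return f"{mib / 1024:.1f}GiB"

try:
    import torch
    from transformers import TrainerCallback

    class MemoryUsageCallback(TrainerCallback):
        """Print per-device Torch allocator stats at the end of each optimizer step."""

        def on_step_end(self, args, state, control, **kwargs):
            torch.cuda.synchronize()  # flush pending kernels before reading stats
            for dev in range(torch.cuda.device_count()):
                alloc_mib = torch.cuda.memory_allocated(dev) / 2**20
                resv_mib = torch.cuda.memory_reserved(dev) / 2**20
                print(f"Device {dev}: Allocated: {mib_to_gib(alloc_mib)}, "
                      f"Reserved {mib_to_gib(resv_mib)}")
except ImportError:
    pass  # torch/transformers not installed; the helper above still works
```

Note that `torch.cuda.memory_allocated`/`memory_reserved` only see the Torch caching allocator, whereas NVML reports the device's total usage across all allocators and processes.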
One would expect memory usage at the end of step 1 to equal that at the end of step 2.
Yet with the unpaged optimizer, Torch-reported usage leaps by 11.2GiB: 70.4GiB at the end of step 1 vs 81.6GiB at the end of step 2.
The leap appears to be entirely in PyTorch reserved memory (32.6GiB -> 43.9GiB); allocated memory is unchanged.
One would also expect the paged optimizer to use the same or less memory than the unpaged one.
Yet while the unpaged optimizer uses 70.4GiB then 81.6GiB (Torch-reported, steps 1 and 2), the paged optimizer's NVML-reported usage is 91.2GiB (Torch itself reports only 70.4GiB for the paged run).
By the end of step 2, the two optimizers more or less agree on Torch reserved memory (unpaged=43.9GiB, paged=45.2GiB).
But throughout, the unpaged optimizer shows far more Torch allocated memory (37.7GiB) than the paged optimizer (25.2GiB).
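For reference, the GiB figures quoted above are simply the MiB totals from the logs below divided by 1024 (figures from the logs; the conversion is mine):

```python
def gib(mib: int) -> float:
    """Convert MiB to GiB, rounded to one decimal (1 GiB = 1024 MiB)."""
    return round(mib / 1024, 1)

# Unpaged (adamw_8bit), Torch-reported overall "Used":
assert gib(72040) == 70.4   # end of step 1
assert gib(83586) == 81.6   # end of step 2: a leap of 11.2 GiB

# ...and that leap is in Torch reserved memory:
assert gib(33392) == 32.6   # end of step 1
assert gib(44938) == 43.9   # end of step 2

# Paged (paged_adamw_8bit), NVML-reported overall:
assert gib(93413) == 91.2
```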
The step-1-to-step-2 leap is probably not related to gradient accumulation. We do use gradient accumulation, but gradients are accumulated during microsteps, and `on_step_end` is called after all microsteps.
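To illustrate why (a toy simulation of the Trainer's loop, not its actual code): with `gradient_accumulation_steps=2`, backward runs twice per optimizer step, and `on_step_end` fires once, after both microsteps.

```python
def simulate_trainer_loop(num_steps: int, gradient_accumulation_steps: int) -> list:
    """Toy model of the HF Trainer step loop: one backward per microstep,
    then optimizer.step() / on_step_end once per global step."""
    events = []
    for step in range(num_steps):
        for micro in range(gradient_accumulation_steps):
            events.append(("backward", step, micro))  # gradients accumulate here
        events.append(("optimizer.step", step, None))
        events.append(("on_step_end", step, None))    # memory is measured here
    return events

events = simulate_trainer_loop(num_steps=2, gradient_accumulation_steps=2)
```

So any gradient buffers from microsteps already exist by the time the step-1 measurement is taken, and cannot explain growth between steps 1 and 2.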
I have tried modifying my `MemoryUsageCallback` to call `torch.cuda.synchronize()` before reading the stats; this does not alter the result.
I ran the program like so (`device_map='auto'` over 2xA40 on Linux, CUDA 12.1, compute capability 8.6):
```shell
ACCELERATE_MIXED_PRECISION=bf16 python -m qlora \
  --device_map_auto \
  --disable_tqdm True \
  --use_auth_token True \
  --model_name_or_path meta-llama/Llama-2-7b-chat-hf \
  --use_flash_llama \
  --trust_remote_code \
  --dataset prm800k-solutions \
  --dataset_format prm800k-solutions \
  --max_memory_MB 48000 --simulate_worst_case_seq_len \
  --truncate_toward_center \
  --source_max_len 2048 \
  --target_max_len 2048 \
  --gradient_accumulation_steps 2 \
  --per_device_train_batch_size 1 \
  --per_device_eval_batch_size 1 \
  --optim adamw_8bit \
  --learning_rate 0.0002 \
  --save_steps 196 \
  --save_total_limit 1 \
  --max_steps 256 \
  --evaluation_strategy steps \
  --eval_steps 196 \
  --measure_memory \
  --terminate_after_step 8 \
  --bits 32 \
  --full_finetune
```
The memory readings I get (from Torch and from NVML) are as follows.
adamw_8bit:

```
step 1
NVML memory stats (used+reserved, all processes):
  Device 0: Used 36314MiB / 49140MiB
  Device 1: Used 37746MiB / 49140MiB
  Overall:  Used 74061MiB / 98280MiB
Torch memory stats (allocated, reserved):
  Device 0: Used 35304MiB (Allocated: 19323MiB, Reserved 15980MiB)
  Device 1: Used 36736MiB (Allocated: 19323MiB, Reserved 17412MiB)
  Overall:  Used 72040MiB (Allocated: 38647MiB, Reserved 33392MiB)
step 2
NVML memory stats (used+reserved, all processes):
  Device 0: Used 42086MiB / 49140MiB
  Device 1: Used 43520MiB / 49140MiB
  Overall:  Used 85607MiB / 98280MiB
Torch memory stats (allocated, reserved):
  Device 0: Used 41076MiB (Allocated: 19323MiB, Reserved 21752MiB)
  Device 1: Used 42510MiB (Allocated: 19323MiB, Reserved 23186MiB)
  Overall:  Used 83586MiB (Allocated: 38647MiB, Reserved 44938MiB)
step 3
NVML memory stats (used+reserved, all processes):
  Device 0: Used 42086MiB / 49140MiB
  Device 1: Used 43520MiB / 49140MiB
  Overall:  Used 85607MiB / 98280MiB
Torch memory stats (allocated, reserved):
  Device 0: Used 41076MiB (Allocated: 19323MiB, Reserved 21752MiB)
  Device 1: Used 42510MiB (Allocated: 19323MiB, Reserved 23186MiB)
  Overall:  Used 83586MiB (Allocated: 38647MiB, Reserved 44938MiB)
```
paged_adamw_8bit:

```
step 1
NVML memory stats (used+reserved, all processes):
  Device 0: Used 49138MiB / 49140MiB
  Device 1: Used 44274MiB / 49140MiB
  Overall:  Used 93413MiB / 98280MiB
Torch memory stats (allocated, reserved):
  Device 0: Used 35304MiB (Allocated: 12897MiB, Reserved 22406MiB)
  Device 1: Used 36736MiB (Allocated: 12897MiB, Reserved 23838MiB)
  Overall:  Used 72040MiB (Allocated: 25795MiB, Reserved 46244MiB)
step 2
NVML memory stats (used+reserved, all processes):
  Device 0: Used 49138MiB / 49140MiB
  Device 1: Used 44276MiB / 49140MiB
  Overall:  Used 93415MiB / 98280MiB
Torch memory stats (allocated, reserved):
  Device 0: Used 35306MiB (Allocated: 12897MiB, Reserved 22408MiB)
  Device 1: Used 36738MiB (Allocated: 12897MiB, Reserved 23840MiB)
  Overall:  Used 72044MiB (Allocated: 25795MiB, Reserved 46248MiB)
step 3
NVML memory stats (used+reserved, all processes):
  Device 0: Used 49138MiB / 49140MiB
  Device 1: Used 44276MiB / 49140MiB
  Overall:  Used 93415MiB / 98280MiB
Torch memory stats (allocated, reserved):
  Device 0: Used 35306MiB (Allocated: 12897MiB, Reserved 22408MiB)
  Device 1: Used 36738MiB (Allocated: 12897MiB, Reserved 23840MiB)
  Overall:  Used 72044MiB (Allocated: 25795MiB, Reserved 46248MiB)
```
Versions:
- bitsandbytes 0.41.1
- transformers 4.32.1
- accelerate 0.22.0
- torch 2.1.0.dev20230802+cu121