@main-horse
Last active December 21, 2024 16:27
showcasing the deadlock behavior of zero2 vs zero3 on deepspeed vs fsdp
test_ds.py

# uv pip install deepspeed torch transformers
import os
import torch
import deepspeed
from transformers import AutoModelForCausalLM

def create_config(zero_stage=2):
    return {
        "train_batch_size": 2,  # global_batch_size = micro_batch_size * grad_acc * dp_size
        "train_micro_batch_size_per_gpu": 1,
        "gradient_accumulation_steps": 1,
        "steps_per_print": 1,
        "zero_optimization": {
            "stage": zero_stage,
            "overlap_comm": True,
            "contiguous_gradients": True,
            "reduce_bucket_size": 5e7,
            "stage3_prefetch_bucket_size": 5e7,
            "stage3_param_persistence_threshold": 1e5
        },
        "fp16": {"enabled": True, "initial_scale_power": 12},
        "gradient_clipping": 1.0,
        "prescale_gradients": False,
        "wall_clock_breakdown": False
    }

def main():
    zero_stage = int(os.getenv("ZEROSTAGE", "2"))
    local_rank = int(os.getenv('LOCAL_RANK', '0'))

    # Initialize model
    model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m", torch_dtype=torch.float16)
    engine, _, _, _ = deepspeed.initialize(
        model=model,
        optimizer=torch.optim.AdamW(model.parameters(), lr=1e-5),
        config=create_config(zero_stage),
        model_parameters=model.parameters()
    )

    # Create dummy input and target
    input_ids = torch.randint(0, 1000, (1, 32), device=local_rank)
    labels = torch.randint(0, 1000, (1, 32), device=local_rank)

    # Perform initial forward-backward pass on all ranks
    print(f"Performing initial forward-backward pass on rank {local_rank}")
    outputs = engine(input_ids=input_ids, labels=labels)
    engine.backward(outputs.loss)
    engine.step()
    print(f"Completed initial forward-backward pass on rank {local_rank}")
    torch.distributed.barrier()

    # Only run on rank 0
    if local_rank == 0:
        print(f"Running inference on rank {local_rank}")
        with torch.no_grad():
            _ = engine(input_ids=input_ids)
            print("First forward pass completed")
            _ = engine(input_ids=input_ids)
            print("Second forward pass completed")

    # Check parameter status
    for n, p in engine.module.named_parameters():
        if hasattr(p, 'ds_status'):
            print(f"Parameter {n}: Status = {p.ds_status}")

    if local_rank == 0: print("Test completed successfully")

if __name__ == "__main__": main()

# ZEROSTAGE=2 deepspeed --num_gpus=2 test_ds.py  # <-- won't deadlock
# ZEROSTAGE=3 deepspeed --num_gpus=2 test_ds.py  # <-- will deadlock
test_fs.py

# uv pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126
import os
import sys
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.multiprocessing as mp
from torch.distributed.tensor import DeviceMesh
# do not use torch<=2.5.1 or this will fail:
from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy, FullyShardedDataParallel, MixedPrecision, ShardingStrategy

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 10)
        self.fc2 = nn.Linear(10, 10)

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        return x

def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

def cleanup():
    dist.destroy_process_group()

def run_demo(rank, world_size):
    print(f"Running on rank {rank}")
    setup(rank, world_size)

    # Create model and move to GPU
    model = SimpleModel().cuda(rank)

    # Create a 1D device mesh spanning all ranks for FSDP
    mesh = DeviceMesh("cuda", torch.arange(world_size))

    # Wrap with FSDP2 (fully_shard, reshard_after_forward=False) when a CLI arg is given,
    # otherwise with FSDP1 in the zero2-like SHARD_GRAD_OP mode
    model = fully_shard(
        model,
        mesh=mesh,
        reshard_after_forward=False,
        mp_policy=MixedPrecisionPolicy(
            param_dtype=torch.float16,
            reduce_dtype=torch.float16,
        )
    ) if len(sys.argv) > 1 else FullyShardedDataParallel(
        model,
        sharding_strategy=ShardingStrategy.SHARD_GRAD_OP,
        mixed_precision=MixedPrecision(torch.float16, torch.float16),
    )

    # Test regular forward/backward on all ranks
    batch = torch.randn(20, 10).cuda(rank)
    output = model(batch)
    loss = output.sum()
    loss.backward()
    print(f"Rank {rank}: Completed regular forward/backward")
    dist.barrier()
    print(f"Rank {rank}: Passed first barrier")

    # Now test single-rank forward
    if rank == 0:
        print("Rank 0: Starting single-rank forward")
        single_batch = torch.randn(20, 10).cuda(rank)
        with torch.no_grad():
            single_output = model(single_batch)
        print("Rank 0: Completed single-rank forward")

    dist.barrier()
    print(f"Rank {rank}: Passed final barrier")
    cleanup()

if __name__ == "__main__":
    mp.spawn(run_demo, args=(2,), nprocs=2, join=True)

# torchrun --nproc-per-node 2 test_fs.py 2  # <-- fsdp2 test, will hang
# torchrun --nproc-per-node 2 test_fs.py    # <-- fsdp1 test, will hang
main-horse commented Dec 16, 2024

key takeaways from this:

  • zero3, as implemented in both pytorch fsdp and microsoft deepspeed, will naturally deadlock on rank-conditional execution: every forward pass must all-gather the sharded parameters, and that collective blocks until every rank joins it (see the sketch below)
  • zero2 will not deadlock in deepspeed, because each rank keeps a full copy of the fp16 params and only syncs them during the optimizer step, so a lone forward pass needs no collectives
  • zero2 will deadlock in pytorch fsdp. please see [blog tbd]
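
To make the collective mismatch concrete, here is a minimal sketch that is not part of the original gist: the filename and tensor shapes are made up, and it only mimics the communication pattern rather than fsdp/deepspeed internals. A zero3-style forward amounts to an all-gather over parameter shards, and an all-gather issued by only one rank can never complete. Run it with 2 processes and it reproduces the hang in isolation (gloo backend, so no GPUs needed).

# hang_sketch.py -- hypothetical illustration only, not how fsdp/deepspeed implement this
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def run(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29501'
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # "zero3-style" forward: every rank must contribute its parameter shard
    shard = torch.full((4,), float(rank))
    gathered = [torch.empty(4) for _ in range(world_size)]
    dist.all_gather(gathered, shard)   # all ranks participate -> completes fine
    print(f"rank {rank}: collective forward ok")

    if rank == 0:
        # rank-conditional forward: rank 0 asks for the shards again,
        # but rank 1 never calls all_gather -> rank 0 blocks here forever
        dist.all_gather(gathered, shard)
        print("rank 0: you will never see this line")

    dist.barrier()                     # rank 1 also blocks here, waiting for rank 0
    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(run, args=(2,), nprocs=2, join=True)

Under deepspeed zero2 the rank-0 branch would issue no collective at all, since each rank already holds the full params, which matches the "won't deadlock" case in test_ds.py above.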
