Showcasing the deadlock behavior of ZeRO-2 vs ZeRO-3 on DeepSpeed vs FSDP.
test_ds.py:
# uv pip install deepspeed torch transformers
import os
import torch
import deepspeed
from transformers import AutoModelForCausalLM

def create_config(zero_stage=2):
    return {
        "train_batch_size": 2,  # global_batch_size = micro_batch_size * grad_acc * dp_size
        "train_micro_batch_size_per_gpu": 1,
        "gradient_accumulation_steps": 1,
        "steps_per_print": 1,
        "zero_optimization": {
            "stage": zero_stage,
            "overlap_comm": True,
            "contiguous_gradients": True,
            "reduce_bucket_size": 5e7,
            "stage3_prefetch_bucket_size": 5e7,
            "stage3_param_persistence_threshold": 1e5
        },
        "fp16": {"enabled": True, "initial_scale_power": 12},
        "gradient_clipping": 1.0,
        "prescale_gradients": False,
        "wall_clock_breakdown": False
    }

def main():
    zero_stage = int(os.getenv("ZEROSTAGE", "2"))
    local_rank = int(os.getenv('LOCAL_RANK', '0'))

    # Initialize model
    model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m", torch_dtype=torch.float16)
    engine, _, _, _ = deepspeed.initialize(
        model=model,
        optimizer=torch.optim.AdamW(model.parameters(), lr=1e-5),
        config=create_config(zero_stage),
        model_parameters=model.parameters()
    )

    # Create dummy input and target
    input_ids = torch.randint(0, 1000, (1, 32), device=local_rank)
    labels = torch.randint(0, 1000, (1, 32), device=local_rank)

    # Perform initial forward-backward pass on all ranks
    print(f"Performing initial forward-backward pass on rank {local_rank}")
    outputs = engine(input_ids=input_ids, labels=labels)
    engine.backward(outputs.loss)
    engine.step()
    print(f"Completed initial forward-backward pass on rank {local_rank}")
    torch.distributed.barrier()

    # Only run on rank 0
    if local_rank == 0:
        print(f"Running inference on rank {local_rank}")
        with torch.no_grad():
            _ = engine(input_ids=input_ids)
            print("First forward pass completed")
            _ = engine(input_ids=input_ids)
            print("Second forward pass completed")

    # Check parameter status
    for n, p in engine.module.named_parameters():
        if hasattr(p, 'ds_status'):
            print(f"Parameter {n}: Status = {p.ds_status}")

    if local_rank == 0: print("Test completed successfully")

if __name__ == "__main__": main()

# ZEROSTAGE=2 deepspeed --num_gpus=2 test_ds.py  # <-- won't deadlock
# ZEROSTAGE=3 deepspeed --num_gpus=2 test_ds.py  # <-- will deadlock
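The ZeRO-3 hang comes from the fact that every forward pass has to all-gather parameter shards from all ranks, so a forward entered only by rank 0 waits forever. One way to keep rank-0-only inference safe is for every rank to enter the forward and for only rank 0 to consume the result. A minimal sketch of that pattern (the helper name safe_rank0_inference is illustrative, not from the gist):

import torch

def safe_rank0_inference(engine, input_ids, local_rank):
    # Illustrative sketch: all ranks call the forward so ZeRO-3's parameter
    # all-gathers can complete; only rank 0 keeps the output.
    with torch.no_grad():
        outputs = engine(input_ids=input_ids)
    return outputs if local_rank == 0 else None

With ZEROSTAGE=3 this pattern should complete on both ranks, since the collectives inside forward are entered symmetrically.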
test_fs.py:
# uv pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126
import os
import sys
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.multiprocessing as mp
from torch.distributed.tensor import DeviceMesh
# do not use torch<=2.5.1 or this will fail:
from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy, FullyShardedDataParallel, MixedPrecision, ShardingStrategy

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 10)
        self.fc2 = nn.Linear(10, 10)

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        return x

def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

def cleanup():
    dist.destroy_process_group()

def run_demo(rank, world_size):
    print(f"Running on rank {rank}")
    setup(rank, world_size)

    # Create model and move to GPU
    model = SimpleModel().cuda(rank)

    # Create 1D device mesh for FSDP
    mesh = DeviceMesh("cuda", torch.arange(world_size))

    # Wrap with FSDP2 (fully_shard, explicitly setting reshard_after_forward=False)
    # when a CLI argument is given, otherwise with FSDP1 (SHARD_GRAD_OP)
    model = fully_shard(
        model,
        mesh=mesh,
        reshard_after_forward=False,
        mp_policy=MixedPrecisionPolicy(
            param_dtype=torch.float16,
            reduce_dtype=torch.float16,
        )
    ) if len(sys.argv) > 1 else FullyShardedDataParallel(
        model,
        sharding_strategy=ShardingStrategy.SHARD_GRAD_OP,
        mixed_precision=MixedPrecision(torch.float16, torch.float16),
    )

    # Test regular forward/backward on all ranks
    batch = torch.randn(20, 10).cuda(rank)
    output = model(batch)
    loss = output.sum()
    loss.backward()
    print(f"Rank {rank}: Completed regular forward/backward")
    dist.barrier()
    print(f"Rank {rank}: Passed first barrier")

    # Now test single-rank forward
    if rank == 0:
        print("Rank 0: Starting single-rank forward")
        single_batch = torch.randn(20, 10).cuda(rank)
        with torch.no_grad():
            single_output = model(single_batch)
        print("Rank 0: Completed single-rank forward")

    dist.barrier()
    print(f"Rank {rank}: Passed final barrier")
    cleanup()

if __name__ == "__main__":
    mp.spawn(run_demo, args=(2,), nprocs=2, join=True)
# The script spawns both ranks itself via mp.spawn, so launch it directly:
# python test_fs.py 2  # <-- fsdp2 test, will hang
# python test_fs.py    # <-- fsdp1 test, will hang
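The same principle applies to both FSDP paths: the pre-forward parameter all-gather is a collective, so a forward entered by rank 0 alone blocks. A rough sketch of a collective-safe variant (the function name and the dummy-batch idea are illustrative, assuming the non-zero ranks can afford a throwaway forward):

import torch
import torch.nn as nn

def safe_rank0_forward(model: nn.Module, batch: torch.Tensor, rank: int):
    # Illustrative sketch: ranks != 0 feed a dummy batch of the same shape purely
    # to participate in FSDP's pre-forward all-gather; only rank 0 keeps the output.
    if rank != 0:
        batch = torch.zeros_like(batch)
    with torch.no_grad():
        out = model(batch)
    return out if rank == 0 else None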
Key learnings from this:
- DeepSpeed ZeRO-2 keeps a full copy of the parameters on every rank, so a forward pass entered by rank 0 alone needs no collective communication and completes (the ZEROSTAGE=2 run does not deadlock).
- DeepSpeed ZeRO-3 shards parameters across ranks and all-gathers them inside every forward; when only rank 0 enters the forward, the all-gather never gets its partners and the job deadlocks.
- Both FSDP1 with SHARD_GRAD_OP and FSDP2 with reshard_after_forward=False also hang on the single-rank forward, presumably because parameters are resharded after backward, so the next forward again needs an all-gather from every rank.
- In general, any code path that can trigger a collective (parameter all-gather, barrier, reduce) must be entered by every rank in the process group, even if only one rank needs the result.