Reproduce FairScale FSDP optimizer state save bug
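
The script below trains a ResNet-50 wrapped in FairScale FSDP on two or more GPUs and saves a consolidated checkpoint at iteration 20 via gather_full_optim_state_dict. With torch 1.12.0 and fairscale 0.4.6, the gather appears to rewrite each Adam "step" buffer in place, turning a singleton tensor into a plain Python int, so the first optimizer.step() after saving breaks; converting the buffers back into tensors (the "uncomment to fix" block at the bottom of the training loop) works around it.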

'''
python==3.7.5
pytorch==1.12.0
fairscale==0.4.6 (can't upgrade due to being restricted to python3.7)
'''
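
# How to run (assuming this file is saved as fsdp_optim_save_bug.py on a
# machine with at least 2 CUDA GPUs; the filename is illustrative):
#   python fsdp_optim_save_bug.py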

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from fairscale.optim.grad_scaler import ShardedGradScaler
from torch.cuda.amp import autocast
from torchvision.models import resnet50


def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12353'
    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)


def cleanup():
    dist.destroy_process_group()


def save_snapshot(model, optimizer, rank):
    print(f"calling save on rank {torch.distributed.get_rank()}")
    state = model.state_dict()
    # collective call: every rank must participate; the consolidated
    # optimizer state is then used on rank 0 below
    optim_state = model.gather_full_optim_state_dict(optimizer)
    if rank == 0:
        for k in state:
            state[k] = state[k].cpu()
        checkpoint = {
            "state_dict": state,
            "optimizer": optim_state,
        }
        torch.save(checkpoint, "checkpoint.ckpt")
        print(f"Model saved on rank {rank}")


def load_snapshot(snapshot, model, optimizer):
    torch.distributed.barrier()
    checkpoint = torch.load(snapshot, map_location="cpu")
    curr_opt_state_dict = checkpoint["optimizer"]
    # carve this rank's shard back out of the consolidated optimizer state
    optim_shard_dict = model.get_shard_from_optim_state_dict(curr_opt_state_dict)
    optimizer.load_state_dict(optim_shard_dict)
    return model, optimizer
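

# Minimal sketch of the failure mode (debug_step_types is a hypothetical helper,
# not called anywhere in this repro): print the type of every per-parameter
# "step" buffer. Calling it before and after save_snapshot() should show
# torch.Tensor flipping to int, which is what the commented-out fix in the
# training loop below undoes. Note that load_snapshot above is provided for
# completeness; the repro itself never calls it.
def debug_step_types(optimizer):
    for bufs in optimizer.state_dict()["state"].values():
        if "step" in bufs:
            print(type(bufs["step"]))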


def demo_basic(rank, world_size):
    print(f"Running basic FSDP example on rank {rank}.")
    setup(rank, world_size)
    torch.cuda.set_device(rank)
    model = resnet50(num_classes=1000)
    set_amp = False
    model = FSDP(model, mixed_precision=set_amp)
    model.to(rank)
    model.train()
    optimizer = torch.optim.AdamW(params=model.parameters(), lr=1e-4)
    scaler = ShardedGradScaler(enabled=set_amp)
    criterion = torch.nn.CrossEntropyLoss().to(rank)
    image = torch.rand((1, 3, 224, 224)).to(rank)
    target = torch.empty((1, 1000)).random_(2).to(rank)
    for iter_idx in range(40):
        model.zero_grad(set_to_none=True)
        with autocast(enabled=set_amp):
            preds = model(image)
            loss = criterion(preds, target)
        if set_amp:
            scale = scaler.get_scale()
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
            if rank == 0:
                print(loss, scale)
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            if rank == 0:
                print(loss)
        del loss
        if iter_idx == 20:
            save_snapshot(model, optimizer, rank)
            # uncomment to fix: the gather above leaves each "step" buffer as a
            # plain int, and optimizer.state_dict() returns references into the
            # live optimizer state, so re-wrapping the values mutates it in place
            # osd = optimizer.state_dict()
            # for _, bufs in osd["state"].items():
            #     if "step" in bufs.keys():
            #         # convert state_step back from int into a singleton tensor
            #         bufs["step"] = torch.tensor(bufs["step"])
    cleanup()


def run_demo(demo_fn, world_size):
    # mp.spawn prepends the process index (the rank) to args when calling demo_fn
    mp.spawn(demo_fn,
             args=(world_size,),
             nprocs=world_size,
             join=True)


if __name__ == "__main__":
    n_gpus = torch.cuda.device_count()
    assert n_gpus >= 2, f"Requires at least 2 GPUs to run, but got {n_gpus}"
    world_size = n_gpus
    run_demo(demo_basic, world_size)
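
# Expected behavior on this version combination (the exact traceback depends on
# the build): iterations 0-20 run and print the loss, rank 0 writes
# checkpoint.ckpt, and the first optimizer.step() after the save then fails
# unless the "uncomment to fix" block above is enabled.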