Reproduce FSDP optimizer state save bug
'''
python==3.7.5
pytorch==1.12.0
fairscale==0.4.6 (can't upgrade due to being restricted to python3.7)
'''
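# What this reproduces: the script spawns one process per GPU, wraps a
# torchvision ResNet-50 in fairscale's FullyShardedDataParallel, trains it on
# a single dummy batch, and at iteration 20 gathers and saves the full
# optimizer state. The commented-out block in demo_basic() shows the
# workaround: converting the "step" entries back from ints into tensors.
# Run with `python <this file>` on a machine with at least 2 GPUs.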
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from fairscale.optim.grad_scaler import ShardedGradScaler
from torch.cuda.amp import autocast
from torchvision.models import resnet50

def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12353'
    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()

def save_snapshot(model, optimizer, rank):
    print(f"calling save on rank {torch.distributed.get_rank()}")
    state = model.state_dict()
    # Consolidate the sharded optimizer state across ranks; this is a
    # collective call, but only rank 0 writes the checkpoint below.
    optim_state = model.gather_full_optim_state_dict(optimizer)
    if rank == 0:
        for k in state:
            state[k] = state[k].cpu()
        checkpoint = {
            "state_dict": state,
            "optimizer": optim_state,
        }
        torch.save(checkpoint, "checkpoint.ckpt")
        print(f"Model saved on rank {rank}")

def load_snapshot(snapshot, model, optimizer):
    # Not called in this repro, but shows the matching load path.
    torch.distributed.barrier()
    checkpoint = torch.load(snapshot, map_location="cpu")
    curr_opt_state_dict = checkpoint["optimizer"]
    # Pull this rank's shard back out of the consolidated optimizer state.
    optim_shard_dict = model.get_shard_from_optim_state_dict(curr_opt_state_dict)
    optimizer.load_state_dict(optim_shard_dict)
    return model, optimizer

def demo_basic(rank, world_size):
    print(f"Running basic FSDP example on rank {rank}.")
    setup(rank, world_size)
    torch.cuda.set_device(rank)

    model = resnet50(num_classes=1000)
    set_amp = False
    model = FSDP(model, mixed_precision=set_amp)
    model.to(rank)
    model.train()

    optimizer = torch.optim.AdamW(params=model.parameters(), lr=1e-4)
    scaler = ShardedGradScaler(enabled=set_amp)
    criterion = torch.nn.CrossEntropyLoss().to(rank)

    # Dummy data: one random image and a random 0/1 target vector.
    image = torch.rand((1, 3, 224, 224)).to(rank)
    target = torch.empty((1, 1000)).random_(2).to(rank)

    for iter_idx in range(40):
        model.zero_grad(set_to_none=True)
        with autocast(enabled=set_amp):
            preds = model(image)
            loss = criterion(preds, target)
        if set_amp:
            scale = scaler.get_scale()
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
            if rank == 0:
                print(loss, scale)
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            if rank == 0:
                print(loss)
        del loss

        if iter_idx == 20:
            # Gathering/saving the full optimizer state mid-training is what
            # triggers the bug; the commented-out block below is the
            # workaround, converting the "step" entries in the local
            # optimizer state back from ints into tensors.
            save_snapshot(model, optimizer, rank)
            # uncomment to fix
            # osd = optimizer.state_dict()
            # for _, bufs in osd["state"].items():
            #     if "step" in bufs.keys():
            #         # convert state_step back from int into a singleton tensor
            #         bufs["step"] = torch.tensor(bufs["step"])
    cleanup()
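
# Illustrative sketch: the commented-out fix above packaged as a helper. The
# helper name is not part of the fairscale or PyTorch API; the conversion is
# taken from the commented block in demo_basic() and relies, as that block
# does, on optimizer.state_dict()["state"] holding references to the live
# per-parameter state dicts.
def restore_step_tensors(optimizer):
    for bufs in optimizer.state_dict()["state"].values():
        if "step" in bufs and not torch.is_tensor(bufs["step"]):
            # convert the step count back from an int into a singleton tensor
            bufs["step"] = torch.tensor(bufs["step"])
# Hypothetical usage: call restore_step_tensors(optimizer) right after
# save_snapshot(...) in demo_basic() instead of uncommenting the block there.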

def run_demo(demo_fn, world_size):
    # Spawn one process per GPU; each process receives its rank as the
    # first argument, followed by args.
    mp.spawn(demo_fn,
             args=(world_size,),
             nprocs=world_size,
             join=True)

if __name__ == "__main__":
    n_gpus = torch.cuda.device_count()
    assert n_gpus >= 2, f"Requires at least 2 GPUs to run, but got {n_gpus}"
    world_size = n_gpus
    run_demo(demo_basic, world_size)