@min-xu-ai
Created September 16, 2022 03:16
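The repro below initializes a 100x100 tensor in two ways, selected by the first command-line argument: as a freshly allocated tensor (0) or as a view carved out of a much larger flat buffer (1), mirroring FSDP's parameter flattening. It prints a random draw before and after the init, plus the tensor norm, so the two RNG trajectories can be compared.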
import math
import sys

import torch
from torch.nn import init

# Imported for context; FSDP itself is not exercised in this repro.
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

torch.set_printoptions(precision=20)
torch.manual_seed(12)
print("RAND", torch.rand(1))

if sys.argv[1] == "0":
    # Standalone case: allocate and fill a 100x100 tensor directly.
    t = torch.rand(100, 100)
elif sys.argv[1] == "1":
    # Flattened case: carve the same 100x100 shape out of a much
    # larger flat buffer, as FSDP does when it flattens parameters.
    t = torch.rand(100 * 100 * 100)
    t = t[:100 * 100].reshape(100, 100)
else:
    assert 0, "expected argument 0 or 1"

init.uniform_(t)
# init.kaiming_uniform_(t, a=math.sqrt(5))
print("NORM", t.norm())
print("RAND", torch.rand(1))
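Assuming the script is saved as repro.py, run it once per mode and compare the printed NORM and trailing RAND values:

python repro.py 0
python repro.py 1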
@min-xu-ai (Author) commented:
cc @myleott, a very interesting behavior that affects weight init in FSDP's parameter-flattening case.
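For reference, a minimal sketch of the pattern in question (an assumed illustration of flattening, not FSDP's actual implementation): after flattening, each module's weight is just a view into one large flat buffer, so in-place init ops such as init.uniform_ write through a view whose underlying storage is far larger than the weight itself.

import torch
from torch.nn import init

# One flat buffer standing in for all flattened parameters.
flat = torch.empty(100 * 100 * 100)
# A module's 100x100 weight becomes a view into the flat buffer.
w = flat[:100 * 100].view(100, 100)
# Weight init now runs on the view rather than on a standalone tensor.
init.uniform_(w)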
