Skip to content

Instantly share code, notes, and snippets.

@ptrblck
Created September 2, 2022 05:07
Show Gist options
  • Save ptrblck/40ac3188f1676b2dc4a1525d747a6a4e to your computer and use it in GitHub Desktop.
Save ptrblck/40ac3188f1676b2dc4a1525d747a6a4e to your computer and use it in GitHub Desktop.
import torch
import torch.nn as nn

# Setup: a dense embedding trained with Adam, and a sparse-gradient twin
# trained with SparseAdam, starting from identical weights.
emb1 = nn.Embedding(num_embeddings=4, embedding_dim=4)
opt1 = torch.optim.Adam(emb1.parameters(), lr=1.0)

# sparse=True makes the backward pass produce sparse weight gradients;
# copy emb1's state so both embeddings begin at the same point.
emb2 = nn.Embedding(num_embeddings=4, embedding_dim=4, sparse=True)
emb2.load_state_dict(emb1.state_dict())
opt2 = torch.optim.SparseAdam(emb2.parameters(), lr=1.0)
# --- 1st update: both models see indices 0 and 2 ---
x = torch.tensor([0, 2])

loss1 = emb1(x).mean()
loss1.backward()
print(emb1.weight.grad)  # gradients at the expected indices
opt1.step()
opt1.zero_grad()

loss2 = emb2(x).mean()
loss2.backward()
print(emb2.weight.grad)  # sparse gradients at the expected indices
opt2.step()
opt2.zero_grad()

# Compare per-row mean abs difference between the two weight matrices.
print((emb1.weight - emb2.weight).abs().mean(1))
# tensor([2.3544e-06, 0.0000e+00, 2.3544e-06, 0.0000e+00],
#        grad_fn=<MeanBackward1>)
# Only tiny floating point differences — the two updates are equal.
# --- 2nd update at a previously unseen index ---
x = torch.tensor([1])

loss1 = emb1(x).mean()
loss1.backward()
print(emb1.weight.grad)  # gradient at the expected index only
opt1.step()
opt1.zero_grad()

loss2 = emb2(x).mean()
loss2.backward()
print(emb2.weight.grad)  # sparse gradient at the expected index only
opt2.step()
opt2.zero_grad()

# Rows 0 and 2 now diverge: dense Adam keeps updating them via its
# running stats even though their current gradient is zero, while
# SparseAdam leaves rows absent from the sparse gradient untouched.
print((emb1.weight - emb2.weight).abs().mean(1))
# tensor([6.7006e-01, 9.5367e-07, 6.7006e-01, 0.0000e+00],
#        grad_fn=<MeanBackward1>)
# Fake updates: after zero_grad() (without set_to_none), emb1.weight.grad
# is an all-zero dense tensor rather than None, so Adam still has a grad
# to consume on every step.
w1 = emb1.weight.clone()
print(emb1.weight - w1)  # all zeros before stepping
for _ in range(3):
    # dense Adam updates the weights even though the grad is zero:
    # its momentum / running stats are nonzero for previously-seen rows
    opt1.step()
print(emb1.weight - w1)  # nonzero: weights moved despite zero gradients

w2 = emb2.weight.clone()
print(emb2.weight - w2)
for _ in range(3):
    # SparseAdam performs no updates here — no rows appear in the
    # (empty) sparse gradient, so no rows are stepped
    opt2.step()
print(emb2.weight - w2)  # still all zeros
# Now clear grads with set_to_none=True: .grad becomes None instead of a
# zero tensor, and the optimizer skips parameters whose grad is None.
opt1.zero_grad(set_to_none=True)
w1 = emb1.weight.clone()
print(emb1.weight - w1)
for _ in range(3):
    # no fake updates anymore — grad is None, so step() is a no-op here
    opt1.step()
print(emb1.weight - w1)  # all zeros: weights untouched
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment