Skip to content

Instantly share code, notes, and snippets.

@mfornet
Last active April 21, 2023 10:25
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save mfornet/0144ceba3b3e21a4200700cc3b0e4f0b to your computer and use it in GitHub Desktop.
import torch
import torch.nn as nn
def build_model():
    """Return a tiny model: a bias-free 1 -> 1 linear layer followed by a sigmoid.

    The only learnable parameter is ``model[0].weight`` (a 1x1 matrix), which
    the gradient-accumulation demos below inspect.
    """
    layers = [
        nn.Linear(in_features=1, out_features=1, bias=False),
        nn.Sigmoid(),
    ]
    return nn.Sequential(*layers)
def acc_gradients_inplace():
    """Accumulate gradients by letting ``backward()`` add into ``.grad``.

    Seeds the RNG with 0, builds a fresh model and clears its gradients once.
    It then runs two forward/backward passes without clearing in between, so
    each step's gradient is added to ``weight.grad``. It prints the initial
    weight, each step's input/output, and the accumulated gradient.
    """
    torch.manual_seed(0)
    print("Inplace gradients")
    net = build_model()
    layer = net[0]
    print(f"Weight: {layer.weight[0, 0]:.4}")
    net.zero_grad()
    for step in range(2):
        x = torch.randn((1, 1))
        y = net(x)
        print(f"Step: {step} Input: {x[0, 0]:.4} Output: {y[0, 0]:.4}")
        # Intentionally no zero_grad() here: backward() sums into .grad.
        y.backward()
    print("Gradient:", layer.weight.grad[0, 0])
    print()
def acc_gradients_ext():
    """Accumulate gradients manually in a buffer outside the model.

    Seeds the RNG with 0 and builds a fresh model. It clears the gradients
    before every step, runs one forward/backward pass, and adds that step's
    ``weight.grad`` into a separate ``acc`` tensor. It prints the initial
    weight, each step's input/output, and the accumulated gradient.
    """
    torch.manual_seed(0)
    print("External gradients")
    model = build_model()
    weight = model[0].weight
    print(f"Weight: {weight[0, 0]:.4}")
    # zeros_like on the parameter itself already yields a plain tensor
    # (requires_grad defaults to False) with the same shape/dtype/device,
    # so the autograd-bypassing ``.data`` accessor is unnecessary.
    acc = torch.zeros_like(weight)
    for i in range(2):
        # Reset gradients so .grad holds only this step's contribution.
        model.zero_grad()
        inp = torch.randn((1, 1))
        out = model(inp)
        print(f"Step: {i} Input: {inp[0, 0]:.4} Output: {out[0, 0]:.4}")
        out.backward()
        # Accumulate this step's gradient into the external buffer.
        acc += weight.grad
    print("Gradient:", acc[0, 0])
    print()
def acc_gradients_gold():
    """Reference result: sum every step's output, then call backward once.

    Seeds the RNG with 0, builds a fresh model and clears its gradients. It
    runs two forward passes, adds the outputs together and backpropagates the
    total a single time. It prints the initial weight, each step's
    input/output, and the resulting gradient.
    """
    torch.manual_seed(0)
    print("Correct gradients")
    net = build_model()
    print(f"Weight: {net[0].weight[0, 0]:.4}")
    net.zero_grad()
    outputs = []
    for step in range(2):
        x = torch.randn((1, 1))
        y = net(x)
        print(f"Step: {step} Input: {x[0, 0]:.4} Output: {y[0, 0]:.4}")
        outputs.append(y)
    # sum() starts from 0, just like accumulating into ``loss = 0``.
    total = sum(outputs)
    total.backward()
    print("Gradient:", net[0].weight.grad[0, 0])
    print()
if __name__ == "__main__":
    # Run the three accumulation strategies in order so their printed
    # gradients can be compared side by side.
    for demo in (acc_gradients_inplace, acc_gradients_ext, acc_gradients_gold):
        demo()
Inplace gradients
Weight: -0.007487
Step: 0 Input: 0.2072 Output: 0.4996
Step: 1 Input: 0.2699 Output: 0.4995
Gradient: tensor(0.1193)
External gradients
Weight: -0.007487
Step: 0 Input: 0.2072 Output: 0.4996
Step: 1 Input: 0.2699 Output: 0.4995
Gradient: tensor(0.1193)
Correct gradients
Weight: -0.007487
Step: 0 Input: 0.2072 Output: 0.4996
Step: 1 Input: 0.2699 Output: 0.4995
Gradient: tensor(0.1193)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment