# (c) Facebook, Inc. and its affiliates. Confidential and proprietary.
import torch
import torch.nn as nn
from torch.autograd import Function

class PassThrough(Function):
    """Identity Function that logs the gradients flowing back through it."""

    @staticmethod
    def forward(ctx, *inputs):
        # Return inputs unchanged; the tuple becomes this Function's outputs.
        return inputs

    @staticmethod
    def backward(ctx, *grad_outputs):
        # Print, then forward the incoming gradients, one per input.
        print(f"grad_outputs {grad_outputs}")
        return grad_outputs
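
# A minimal standalone check of PassThrough (a sketch; the `_x`/`_y` names
# below are illustrative, not part of the original gist). The forward is an
# identity, so the wrapped tensor keeps its value while its autograd graph
# gains the custom backward. Note this triggers one extra "grad_outputs"
# print before the main loop.
_x = torch.ones(1, requires_grad=True)
(_y,) = PassThrough.apply(_x)  # apply returns a tuple, matching forward
_y.sum().backward()
assert torch.equal(_x.grad, torch.ones(1))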

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.a = nn.Linear(1, 1, bias=False)
        self.b = nn.Linear(1, 1, bias=False)

    def forward(self, x):
        a, b = self.a(x), self.b(x)
        # Route both activations through PassThrough so its backward
        # runs (and logs) during autograd.
        new_a, new_b = PassThrough.apply(a, b)
        return new_a, new_b

model = MyModel()

def print_grads():
    for param_name, param in model.named_parameters():
        print(f"{param_name} : {param.grad}")

inp = torch.ones(1)

print("-- before backward ---")
model.zero_grad()
print_grads()

for _ in range(3):
    model.zero_grad()
    out = model(inp)
    # The loss uses only the first output; the second output's gradient
    # is still materialized for PassThrough.backward (as zeros).
    loss = out[0].sum()
    print("Calling backward...")
    loss.backward()
    print("-- after bwd --")
    print_grads()
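
# Expected end state after the loop (a sketch; the zeros for `b` rely on
# the default ctx.set_materialize_grads(True), under which PassThrough.backward
# receives a zero tensor for the unused second output):
# a.weight.grad equals the input (1.0), and b.weight.grad is zeros, not None.
assert torch.allclose(model.a.weight.grad, torch.ones(1, 1))
assert torch.allclose(model.b.weight.grad, torch.zeros(1, 1))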