Created May 18, 2021 21:58
Save rohan-varma/b2dad07dcc1f7f114ce039758613d3c5 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# (c) Facebook, Inc. and its affiliates. Confidential and proprietary. | |
import torch | |
import torch.nn as nn | |
from torch.autograd import Function | |
class PassThrough(Function):
    """Identity autograd ``Function`` that logs the gradients flowing back.

    ``forward`` hands its inputs back untouched and ``backward`` returns the
    incoming gradients untouched, so the op is mathematically the identity.
    Its only effect is to print, during backward, the gradient autograd
    supplies for each output.

    NOTE(review): for an output that did not contribute to the loss, PyTorch
    materializes the gradient as zeros by default (see
    ``ctx.set_materialize_grads``). So a zero tensor, not ``None``, is likely
    what gets printed for it; confirm against the installed version.
    """

    @staticmethod
    def forward(ctx, *tensors):
        # Nothing to compute: return the tuple of inputs as the outputs.
        return tensors

    @staticmethod
    def backward(ctx, *grads):
        # One gradient per forward output. Show it, then pass it straight
        # through as the gradient of the matching input.
        print(f"grad_outputs {grads}")
        return grads
class MyModel(nn.Module):
    """Two independent bias-free ``nn.Linear(1, 1)`` layers, ``a`` and ``b``.

    Both layer outputs are routed through ``PassThrough``. During backward,
    that prints the gradient arriving for each output, including one the
    caller's loss may not use.
    """

    def __init__(self):
        super().__init__()
        self.a = nn.Linear(1, 1, bias=False)
        self.b = nn.Linear(1, 1, bias=False)

    def forward(self, x):
        """Apply ``a`` and ``b`` to *x* and return ``(a(x), b(x))``.

        Both results go through ``PassThrough`` together, so a single
        backward node observes the gradients for both outputs.
        """
        a, b = self.a(x), self.b(x)
        # Unpacking also checks that PassThrough returned exactly two outputs.
        new_a, new_b = PassThrough.apply(a, b)
        return new_a, new_b
# Module-level model instance shared by the demo: print_grads() with no
# argument and the training loop both read this global name.
model = MyModel()
def print_grads(module=None):
    """Print ``<name> : <grad>`` for every parameter of *module*.

    Args:
        module: the ``nn.Module`` whose parameters are inspected. If omitted
            (or ``None``), the module-level ``model`` is used, so existing
            ``print_grads()`` calls behave exactly as before.

    A parameter whose ``.grad`` is ``None`` is printed as ``None``.
    """
    if module is None:
        # Fall back to the script's global model for backward compatibility.
        module = model
    for param_name, param in module.named_parameters():
        print(f"{param_name} : {param.grad}")
if __name__ == "__main__":
    # Run the demo only when executed as a script, not as a side effect of
    # importing this module.
    inp = torch.ones(1)
    print("-- before backward ---")
    model.zero_grad()
    print_grads()
    for _ in range(3):
        model.zero_grad()
        out = model(inp)
        # Only the first output (from ``model.a``) feeds the loss. The second
        # output is deliberately left unused, so PassThrough.backward shows
        # what autograd passes for it and print_grads shows what ``b`` ends
        # up with.
        loss = out[0].sum()
        print("Calling backward...")
        loss.backward()
        print("-- after bwd --")
        print_grads()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.