@rohan-varma
Created May 18, 2021 03:44
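# Demonstrates routing a module's output tensors through a pass-through
# autograd Function so the gradients arriving at the outputs can be observed
# in its backward pass, seemingly in the spirit of how DDP's _find_tensors
# walks module outputs (the comments below reference it).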
import torch
import torch.nn as nn
from torch.autograd import Function
class PassThrough(Function):
    """Identity autograd Function: returns its inputs unchanged and logs the
    gradients it receives for each output during the backward pass."""

    @staticmethod
    def forward(ctx, *inputs):
        # Hand the inputs back untouched; autograd will route the gradient
        # for each output into backward() below.
        return inputs

    @staticmethod
    def backward(ctx, *grad_outputs):
        print(f"grad_outputs in PassThrough backward {grad_outputs}")
        return grad_outputs
class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.a = nn.Linear(1, 1, bias=False)
        self.b = nn.Linear(1, 1, bias=False)

    def forward(self, x):
        a, b = self.a(x), self.b(x)
        # Get the tensors out of the output tuple. This would be a more
        # general call to _find_tensors.
        ret = a, b
        new_a, new_b = PassThrough.apply(a, b)
        # Reconstruct the tuple from the output tensors. This would require a
        # more general function that repacks the tensor(s) into the data
        # structure (see the find_tensors/repack_tensors sketch after this
        # class).
        ret = new_a, new_b
        return ret
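
# A minimal sketch of the more general helpers the comments above allude to
# (hypothetical names, not the actual torch.nn.parallel internals): flatten an
# arbitrary nested output into a list of tensors, pass them through
# PassThrough, then rebuild the original structure around the new tensors.
def find_tensors(obj):
    # Recursively collect tensors from tensors, lists/tuples, and dicts.
    if isinstance(obj, torch.Tensor):
        return [obj]
    if isinstance(obj, (list, tuple)):
        return [t for item in obj for t in find_tensors(item)]
    if isinstance(obj, dict):
        return [t for value in obj.values() for t in find_tensors(value)]
    return []

def repack_tensors(obj, tensor_iter):
    # Rebuild obj, replacing each tensor with the next one from tensor_iter.
    if isinstance(obj, torch.Tensor):
        return next(tensor_iter)
    if isinstance(obj, (list, tuple)):
        return type(obj)(repack_tensors(item, tensor_iter) for item in obj)
    if isinstance(obj, dict):
        return {k: repack_tensors(v, tensor_iter) for k, v in obj.items()}
    return obj

# Usage sketch: tensors = find_tensors(ret); new = PassThrough.apply(*tensors);
# ret = repack_tensors(ret, iter(new if isinstance(new, tuple) else (new,)))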
model = MyModel()
def print_grads():
    for param_name, param in model.named_parameters():
        print(f"{param_name} : {param.grad}")

inp = torch.ones(1)
print("-- before backward --")
print_grads()
for _ in range(3):
    model.zero_grad()
    out = model(inp)
    loss = out[0].sum()
    print("Calling backward...")
    loss.backward()
    print("-- after bwd --")
    print_grads()
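
# Rough expectation when running this (not captured output): each iteration
# prints the grad_outputs tuple from inside PassThrough.backward, and since
# only out[0] feeds the loss, a.weight ends up with a nonzero gradient while
# b.weight's gradient stays at zero (the unused output's grad is materialized
# as zeros by default).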