@rohan-varma
Created April 13, 2021 04:26

import torch
import torch.nn as nn
from torch.autograd import Function


class PassThrough(Function):
    @staticmethod
    def forward(ctx, *inputs):
        return inputs

    @staticmethod
    def backward(ctx, *grad_outputs):
        print(f"grad_outputs {grad_outputs}")
        # Use gradients to search through the graph as in
        # https://gist.github.com/rohan-varma/7c8dab3635193c04c607e67c4951f519
        return grad_outputs


class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.a = nn.Linear(1, 1, bias=False)
        self.b = nn.Linear(1, 1, bias=False)

    def forward(self, x):
        a, b = self.a(x), self.b(x)
        # Get tensors from tuple. This would be a more general call to
        # _find_tensors.
        ret = a, b
        new_a, new_b = PassThrough.apply(a, b)
        # Reconstruct tuple from output tensors. This would require a more general
        # function that repacks the tensor(s) into the data structure.
        ret = new_a, new_b
        return ret


model = MyModel()
inp = torch.ones(1)
out = model(inp)
# loss = out[0] + out[1]
loss = out[0].sum()
print("Calling backward...")
loss.backward()
print("Done with bwd")