import torch
import torch.nn as nn


def get_param_to_grad_accs(model):
    """Map each parameter to its AccumulateGrad autograd node."""
    param_to_grad_accs = {}
    for param in model.parameters(recurse=True):
        # expand_as(param) is a no-op view, but it records an autograd node
        # whose next_functions[0][0] is the parameter's AccumulateGrad leaf.
        param_tmp = param.expand_as(param)
        grad_acc = param_tmp.grad_fn.next_functions[0][0]
        param_to_grad_accs[param] = grad_acc
    return param_to_grad_accs
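
# A minimal sketch of the expand_as trick in isolation. The names `_w` and
# `_node` are illustrative additions, not part of the original gist; this
# just verifies that the view's grad_fn points back at the AccumulateGrad
# leaf for the parameter.
_w = nn.Linear(1, 1, bias=False).weight
_node = _w.expand_as(_w).grad_fn.next_functions[0][0]
assert type(_node).__name__ == "AccumulateGrad"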

def find_unused_params(model, loss):
    """
    Given a loss and a model, return the list of parameters
    that will not receive a gradient from this loss.
    """
    param_to_grad_accs = get_param_to_grad_accs(model)
    grad_accs = []
    stack = [loss.grad_fn]
    visited = set()
    print(" -- Running DFS -- ")
    while stack:
        fn = stack.pop()
        if fn in visited:
            # A node can be pushed twice when subgraphs are shared;
            # skip anything we have already expanded.
            continue
        visited.add(fn)
        for next_fn in fn.next_functions:
            if next_fn[0] is not None:
                # AccumulateGrad nodes are the leaves where parameter
                # gradients land; record every one the DFS reaches.
                if isinstance(next_fn[0], torch._C._functions.AccumulateGrad):
                    grad_accs.append(next_fn[0])
                if next_fn[0] not in visited:
                    stack.append(next_fn[0])
    # A parameter is unused if the DFS never reached its AccumulateGrad node.
    unused_parameters = []
    for param, grad_acc_for_param in param_to_grad_accs.items():
        if grad_acc_for_param not in grad_accs:
            unused_parameters.append(param)
    print(f"All unused: {unused_parameters}")
    return unused_parameters

# ---- DEMO ----
class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.a = nn.Linear(1, 1, bias=False)
        self.b = nn.Linear(1, 1, bias=False)

    def forward(self, x):
        return self.a(x), self.b(x)


model = Model()
inp = torch.ones(3, 1)
a, b = model(inp)
# Only `a` contributes to the loss, so `self.b`'s weight should be detected
# as unused. (With loss = (a + b).sum(), nothing would be reported.)
loss = a.sum()
unused = find_unused_params(model, loss)
expected_unused_param = list(model.b.parameters())[0]
print(f"expected {expected_unused_param} to be unused")
assert len(unused) == 1 and unused[0] is expected_unused_param