Gist by @apaszke · Last active October 9, 2017 15:55
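# Demonstrates running a backward pass in shards: d loss / d intermediate is
# computed a few Variables at a time with torch.autograd.grad, and the results
# are then propagated to the leaves with a single torch.autograd.backward call.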
import torch
from torch.autograd import Variable
leaves = [Variable(torch.zeros(5, 5), requires_grad=True) for _ in range(10)]
intermediates = [l + i for i, l in enumerate(leaves)]
loss = sum(v * i for i, v in enumerate(intermediates)).sum()
# define a helper for dividing intermediates into groups
def group(l, group_size):
    """Groups l into chunks of size group_size.

    E.g. group([1, 2, 3, 4, 5], 2) -> [[1, 2], [3, 4], [5]]
    """
    return (l[i:i + group_size] for i in range(0, len(l), group_size))
# Compute d loss / d intermediates in chunks of shard_size.
# retain_graph=True keeps the graph alive, since torch.autograd.grad is
# called once per shard over the same loss (and backward runs once more below)
shard_size = 2
d_intermediates = [d_i for intermediates_batch in group(intermediates, shard_size)
                   for d_i in torch.autograd.grad(loss, intermediates_batch,
                                                  retain_graph=True)]
# Compute the rest of the backward pass, propagating the shard gradients
# from the intermediates down to the leaves
torch.autograd.backward(intermediates, d_intermediates)
# Each leaf's gradient should be a 5x5 tensor filled with its index,
# since intermediates[i] = leaves[i] + i and loss sums intermediates[i] * i
for i, l in enumerate(leaves):
    assert l.grad.data.eq(i).all()
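
# ----------------------------------------------------------------------------
# The same pattern on post-0.4 PyTorch, where Variable was merged into Tensor.
# A minimal sketch (not part of the original gist, so treat it as an
# illustration): identical logic, with the shards produced by an explicit loop
# instead of a nested comprehension. Reuses the group() helper defined above.
def sharded_backward_modern():
    leaves = [torch.zeros(5, 5, requires_grad=True) for _ in range(10)]
    intermediates = [l + i for i, l in enumerate(leaves)]
    loss = sum(v * i for i, v in enumerate(intermediates)).sum()
    d_intermediates = []
    for batch in group(intermediates, 2):
        # retain_graph=True lets the shared graph be traversed once per shard
        d_intermediates.extend(torch.autograd.grad(loss, batch, retain_graph=True))
    torch.autograd.backward(intermediates, d_intermediates)
    for i, l in enumerate(leaves):
        assert l.grad.eq(i).all()
# Call sharded_backward_modern() to verify on torch >= 0.4.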