Last active October 9, 2017 15:55
import torch
from torch.autograd import Variable
leaves = [Variable(torch.zeros(5, 5), requires_grad=True) for _ in range(10)]
intermediates = [l + i for i, l in enumerate(leaves)]
loss = sum(v * i for i, v in enumerate(intermediates)).sum()
# define a helper for dividing intermediates into groups
def group(l, group_size):
    """Yields successive chunks of l, each of size group_size (the last may be shorter).
    E.g. list(group([1, 2, 3, 4, 5], 2)) -> [[1, 2], [3, 4], [5]]
    """
    return (l[i:i + group_size] for i in range(0, len(l), group_size))
# Compute the d loss / d intermediates in chunks of shard_size
shard_size = 2
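# retain_graph=True keeps the graph from loss to the intermediates alive,
# so it can be traversed again for each subsequent shard
d_intermediates = [d_i for intermediates_batch in group(intermediates, shard_size)
                   for d_i in torch.autograd.grad(loss, intermediates_batch,
                                                  retain_graph=True)]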
# Compute rest of backward pass
torch.autograd.backward(intermediates, d_intermediates)
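for i, l in enumerate(leaves):
    # The loop body is missing from the source; this check is an assumption.
    # Since loss = sum_i (i * (leaves[i] + i)).sum(), d loss / d leaves[i] == i everywhere.
    assert (l.grad.data == i).all()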