@vadimkantorov
Last active November 12, 2021 16:34
Mini-batching within the model in PyTorch
# Instance-Aware, Context-Focused, and Memory-Efficient Weakly Supervised Object Detection, https://arxiv.org/abs/2004.04725
# https://github.com/NVlabs/wetectron/issues/72
# https://discuss.pytorch.org/t/mini-batching-gradient-accumulation-within-the-model/136460
import torch
import torch.nn as nn
# SequentialBackprop wraps a module so that its backward pass is recomputed and
# accumulated over mini-batches (gradient accumulation within the model).
class SequentialBackprop(nn.Module):
    def __init__(self, module, batch_size = 1):
        super().__init__()
        self.module = module
        self.batch_size = batch_size

    def forward(self, x):
        # run the wrapped module on the detached input; gradients for this segment
        # are recomputed chunk by chunk in Function.backward below
        y = self.module(x.detach())
        return self.Function.apply(x, y, self.batch_size, self.module)

    class Function(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, y, batch_size, module):
            ctx.save_for_backward(x)
            ctx.batch_size = batch_size
            ctx.module = module
            return y

        @staticmethod
        def backward(ctx, grad_output):
            (x,) = ctx.saved_tensors
            grads = []
            # recompute the forward pass chunk by chunk and backpropagate each chunk,
            # accumulating parameter gradients and collecting gradients w.r.t. the input
            for x_mini, g_mini in zip(x.split(ctx.batch_size), grad_output.split(ctx.batch_size)):
                with torch.enable_grad():
                    x_mini = x_mini.detach().requires_grad_()
                    x_mini.retain_grad()
                    y_mini = ctx.module(x_mini)
                    torch.autograd.backward(y_mini, g_mini)
                grads.append(x_mini.grad)
            # gradient w.r.t. x; the other inputs (y, batch_size, module) get no gradient
            return torch.cat(grads), None, None, None
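
# An alternative sketch (an assumption, not part of the original gist): a similar
# compute-for-memory trade can be written with torch.utils.checkpoint applied per chunk.
# Unlike SequentialBackprop above, this keeps a single autograd graph and recomputes each
# chunk's forward during backward instead of accumulating gradients over detached chunks.
# Note: with the default (reentrant) checkpoint, module parameters only receive gradients
# when the input chunk itself requires grad, as it does when fed from an upstream layer.
import torch.utils.checkpoint

def checkpointed_chunks(module, x, batch_size = 1):
    # discard each chunk's intermediate activations after the forward pass and recompute them during backward
    return torch.cat([torch.utils.checkpoint.checkpoint(module, chunk) for chunk in x.split(batch_size)])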
if __name__ == '__main__':
    backbone = nn.Linear(3, 6)
    neck = nn.Linear(6, 12)
    head = nn.Linear(12, 1)
    # only the neck's backward is chunked; backbone and head backpropagate as usual
    model = nn.Sequential(backbone, SequentialBackprop(neck, batch_size = 16), head)

    print('before', neck.weight.grad)
    x = torch.rand(512, 3)
    model(x).sum().backward()
    print('after', neck.weight.grad)
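
    # Sanity check (an addition, not part of the original gist): the chunked backward
    # should match an ordinary full-batch backward through the same modules, up to
    # float round-off from the different summation order.
    grad_chunked = neck.weight.grad.clone()
    for p in model.parameters():
        p.grad = None
    nn.Sequential(backbone, neck, head)(x).sum().backward()
    print('gradients match:', torch.allclose(grad_chunked, neck.weight.grad, atol = 1e-6))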