@RomanSteinberg
Created December 12, 2018 12:59
Strange behavior when freeing buffers
# Description:
# This script is a minimal example of strange buffer-freeing behavior. As written, it triggers the
# following error reported by PyTorch:
# "RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed.
# Specify retain_graph=True when calling backward the first time."
#
# Several statements below are marked with comments; changing any one of them makes the error disappear.
# A workaround sketch is included at the end of the file.

import torch
from torch import nn, cuda
from torch.autograd import Variable, grad


class Flatten(nn.Module):
    def forward(self, x):
        return x.view(x.size()[0], -1)


class BrokenBlock(nn.Module):
    def __init__(self, dim):
        super(BrokenBlock, self).__init__()
        self.conv_block = nn.Sequential(*[nn.InstanceNorm2d(dim, affine=False),
                                          nn.ReLU(inplace=True)])  # change to inplace=False and the error disappears

    def forward(self, x):
        return self.conv_block(x)


class Di(nn.Module):
    def __init__(self, input_shape):
        super(Di, self).__init__()
        input_nc, h, w = input_shape
        sequence = [BrokenBlock(input_nc),
                    BrokenBlock(input_nc),  # comment out this line and the error disappears
                    Flatten(),
                    nn.Linear(input_nc * h * w, 1)]
        self.model = nn.Sequential(*sequence)

    def forward(self, input, parallel_mode=True):
        # Optionally run the forward pass through nn.parallel.data_parallel on GPU 0.
        if isinstance(input.data, cuda.FloatTensor) and parallel_mode:
            return nn.parallel.data_parallel(self.model, input, [0])
        else:
            return self.model(input)


class BrokenModel():
    def __init__(self):
        self.batch_size = 1
        self.netD_A = Di((3, 256, 256)).cuda(device=0)
        self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters())
        self.real_A = Variable(torch.ones((1, 3, 256, 256)).cuda())
        self.real_B = Variable(torch.ones((1, 3, 256, 256)).cuda())

    def backward_D_wgan(self, netD, real, fake):
        loss_D_real = -netD.forward(real, parallel_mode=True).mean()  # change parallel_mode=False and the error disappears
        loss_D_fake = netD.forward(fake, parallel_mode=True).mean()   # change parallel_mode=False and the error disappears
        gradient_penalty = self.__calc_gradient_penalty(netD, real, fake)
        # separate backward passes for the 3 parts of the discriminator loss
        gradient_penalty.backward()  # comment out this line and the error disappears
        loss_D_real.backward()
        loss_D_fake.backward()
        loss_D = loss_D_fake - loss_D_real + gradient_penalty
        return loss_D

    def __calc_gradient_penalty(self, netD, real, fake):
        # WGAN-GP penalty: mean of (||grad_x D(x_interp)||_2 - 1)^2 over interpolated samples
        alpha = torch.rand(self.batch_size, 1, 1, 1).expand_as(real).cuda()
        interpolated = alpha * real.data + (1 - alpha) * fake.data
        interpolated = Variable(interpolated, requires_grad=True).cuda()
        # Calculate probability of interpolated examples
        prob_interpolated = netD.forward(interpolated, parallel_mode=True)  # change parallel_mode=False and the error disappears
        # Calculate gradients of probabilities with respect to examples
        gradients = grad(outputs=prob_interpolated, inputs=interpolated,
                         grad_outputs=torch.ones(prob_interpolated.size()).cuda(),
                         create_graph=True, retain_graph=True)[0]
        gradients_flatten = gradients.view(self.batch_size, -1)
        gradients_norm = torch.sqrt(torch.sum(gradients_flatten ** 2, dim=1) + 1e-12)
        return ((gradients_norm - 1) ** 2).mean()

    def optimize_parameters(self):
        self.optimizer_D_A.zero_grad()
        self.backward_D_wgan(self.netD_A, self.real_A, self.real_B)
        self.optimizer_D_A.step()


model = BrokenModel()
model.optimize_parameters()
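

# --- Workaround sketch (added for illustration; not part of the original repro) ---
# The RuntimeError quoted above already hints at one fix: pass retain_graph=True to the
# earlier backward() calls so their buffers survive for the later ones. Another common
# pattern in WGAN-GP training code is to sum the three loss terms and call backward()
# once, so no part of the graph has to be traversed a second time. The function below
# sketches that second approach; it assumes the gradient penalty is computed by the
# caller (as BrokenModel does), the name backward_D_wgan_combined is hypothetical, and
# the function is never called in this script.
def backward_D_wgan_combined(netD, real, fake, gradient_penalty):
    loss_D_real = -netD.forward(real, parallel_mode=True).mean()
    loss_D_fake = netD.forward(fake, parallel_mode=True).mean()
    loss_D = loss_D_fake - loss_D_real + gradient_penalty
    loss_D.backward()  # a single backward pass over the combined loss
    return loss_D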