@lukoshkin
Last active April 3, 2020 11:04
Improving WGAN with the gradient penalty term
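Following WGAN-GP (Gulrajani et al., 2017), the critic loss is augmented with a penalty on the norm of the critic's gradient at points interpolated between real and generated samples: GP = E[ (||grad_x_hat D(x_hat)||_2 - 1)^2 ], where x_hat = alpha * x_real + (1 - alpha) * x_fake and alpha ~ U(0, 1) is drawn per sample. The helper below computes this term.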
import torch


def calc_grad_penalty(real_samples, fake_samples, net_D):
    """
    Evaluates the gradient penalty for `net_D` and allows other gradients
    to backpropagate through this penalty term.

    Args:
        real_samples - a tensor (presumably, without `grad` attribute)
        fake_samples - a tensor of the same shape as `real_samples`
        net_D - a 'critic' which takes input of the same shape
                as `real_samples`
    """
    # One interpolation coefficient per sample in the batch,
    # broadcast over the remaining dimensions.
    alpha = real_samples.new(
        real_samples.size(0),
        *([1] * (real_samples.dim() - 1))
    ).uniform_().expand(*real_samples.shape)

    # Interpolate between real and fake samples; the interpolates must
    # require grad so that d(outputs)/d(inputs) can be computed.
    inputs = alpha * real_samples + (1 - alpha) * fake_samples.detach()
    inputs.requires_grad_(True)
    outputs = net_D(inputs)

    # Gradient of the critic's output w.r.t. the interpolated inputs;
    # `create_graph=True` keeps the graph so the penalty is differentiable.
    jacobian = torch.autograd.grad(
        outputs=outputs,
        inputs=inputs,
        grad_outputs=torch.ones_like(outputs),
        create_graph=True,
    )[0]

    # Flatten each sample's gradient and penalize deviation of its 2-norm from 1.
    jacobian = jacobian.view(jacobian.size(0), -1)
    return ((jacobian.norm(dim=1) - 1) ** 2).mean()
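
A minimal usage sketch of a single critic update (the names `net_D`, `net_G`, `opt_D`, `z_dim`, and the penalty weight `lambda_gp = 10` are assumptions for illustration, not part of the gist above):

# Sketch of one critic step in a WGAN-GP training loop.
# Assumed: net_D (critic), net_G (generator), opt_D (critic optimizer),
# z_dim (latent size); lambda_gp = 10 follows the WGAN-GP paper.
import torch

lambda_gp = 10

def critic_step(real_samples, net_D, net_G, opt_D, z_dim, device='cpu'):
    opt_D.zero_grad()

    # Generate fakes; detach so generator parameters receive no gradients here.
    z = torch.randn(real_samples.size(0), z_dim, device=device)
    fake_samples = net_G(z).detach()

    # Wasserstein critic loss plus the gradient penalty term.
    loss_D = (net_D(fake_samples).mean()
              - net_D(real_samples).mean()
              + lambda_gp * calc_grad_penalty(real_samples, fake_samples, net_D))
    loss_D.backward()
    opt_D.step()
    return loss_D.item()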