@udithhaputhanthri
Created March 26, 2021 08:23
WGAN_scripts
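import torch

def get_gradient(crit, real_imgs, fake_imgs, epsilon):
    # Interpolate between real and fake images; epsilon holds one weight per sample
    # (e.g. shape (N, 1, 1, 1)). At least one input, typically epsilon, must require
    # grad so that mixed_imgs is part of the autograd graph.
    mixed_imgs = real_imgs * epsilon + fake_imgs * (1 - epsilon)
    mixed_scores = crit(mixed_imgs)
    # Gradient of the critic scores w.r.t. the interpolated images; create_graph=True
    # keeps it differentiable so the penalty can be backpropagated into the critic.
    gradient = torch.autograd.grad(
        outputs=mixed_scores,
        inputs=mixed_imgs,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0]
    return gradient

def gradient_penalty(gradient):
    # Flatten each sample's gradient and take its L2 norm.
    gradient = gradient.view(len(gradient), -1)
    gradient_norm = gradient.norm(2, dim=1)
    # Mean squared deviation of the norm from 1, i.e. mean((||grad||_2 - 1)^2).
    penalty = torch.nn.MSELoss()(gradient_norm, torch.ones_like(gradient_norm))
    return penalty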
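The two helpers above are usually combined into the WGAN-GP critic loss, E[D(fake)] - E[D(real)] + λ · E[(‖∇D(x̂)‖₂ - 1)²]. The sketch below is one possible way to wire that up. To stay self-contained it inlines the same steps as get_gradient and gradient_penalty instead of importing them, and it writes the penalty as ((norm - 1) ** 2).mean(), which matches MSELoss against ones under the default mean reduction. The function name critic_loss_with_gp, the detach of fake_imgs, and the default gp_weight=10.0 are assumptions for illustration, not part of the original gist. A toy linear critic makes the result easy to check: its input gradient is its weight vector for every sample, so with all-ones weights over 4 inputs the norm is 2 and the penalty is (2 - 1)² = 1.

```python
import torch
import torch.nn as nn


def critic_loss_with_gp(crit, real_imgs, fake_imgs, gp_weight=10.0):
    """Hypothetical helper: returns (WGAN-GP critic loss, raw penalty term)."""
    # One mixing weight per sample, broadcast over channel/height/width.
    # requires_grad=True puts the interpolated images into the autograd graph.
    epsilon = torch.rand(len(real_imgs), 1, 1, 1,
                         device=real_imgs.device, requires_grad=True)
    # Assumption: fake images are detached so this step only trains the critic.
    fake_imgs = fake_imgs.detach()
    mixed_imgs = real_imgs * epsilon + fake_imgs * (1 - epsilon)
    mixed_scores = crit(mixed_imgs)
    # Same call as get_gradient: d(score)/d(mixed_imgs), kept differentiable.
    gradient = torch.autograd.grad(
        outputs=mixed_scores,
        inputs=mixed_imgs,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0]
    # Same math as gradient_penalty: mean((||grad||_2 - 1)^2) over the batch.
    gradient_norm = gradient.view(len(gradient), -1).norm(2, dim=1)
    penalty = ((gradient_norm - 1) ** 2).mean()
    # The critic minimizes E[D(fake)] - E[D(real)] + gp_weight * penalty.
    loss = crit(fake_imgs).mean() - crit(real_imgs).mean() + gp_weight * penalty
    return loss, penalty


# Toy check with a linear critic on 1x2x2 "images".
torch.manual_seed(0)
crit = nn.Sequential(nn.Flatten(), nn.Linear(4, 1))
with torch.no_grad():
    crit[1].weight.fill_(1.0)  # ||w||_2 = sqrt(4) = 2
    crit[1].bias.zero_()
real = torch.randn(8, 1, 2, 2)
fake = torch.randn(8, 1, 2, 2)

loss, penalty = critic_loss_with_gp(crit, real, fake)
loss.backward()  # gradients reach the critic through both the Wasserstein and penalty terms
print(penalty.item())  # ≈ 1.0, since every input gradient equals w and (2 - 1)^2 = 1
```

Because create_graph=True is set when taking the input gradient, the penalty stays a function of the critic's weights, which is what lets loss.backward() push the critic toward unit-norm gradients on the interpolated points.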