@oscarknagg
Last active November 25, 2023 03:52
Gist for projected gradient descent adversarial attack using PyTorch
import torch


def projected_gradient_descent(model, x, y, loss_fn, num_steps, step_size, step_norm,
                               eps, eps_norm, clamp=(0, 1), y_target=None):
    """Performs the projected gradient descent attack on a batch of images."""
    x_adv = x.clone().detach().to(x.device)
    targeted = y_target is not None

    for i in range(num_steps):
        _x_adv = x_adv.clone().detach().requires_grad_(True)

        prediction = model(_x_adv)
        loss = loss_fn(prediction, y_target if targeted else y)
        loss.backward()

        with torch.no_grad():
            # Force the gradient step to be a fixed size in a certain norm
            if step_norm == 'inf':
                gradients = _x_adv.grad.sign() * step_size
            else:
                # Note: .view() assumes batched image data as a 4D tensor;
                # normalise the gradient per sample so each step has size step_size
                gradients = _x_adv.grad * step_size / _x_adv.grad\
                    .view(_x_adv.shape[0], -1)\
                    .norm(step_norm, dim=-1)\
                    .view(-1, 1, 1, 1)

            if targeted:
                # Targeted: gradient descent on the loss of the (incorrect) target
                # label w.r.t. the image data
                x_adv -= gradients
            else:
                # Untargeted: gradient ascent on the loss of the correct label
                # w.r.t. the image data
                x_adv += gradients

        # Project back into the eps-ball around x and the valid data range
        if eps_norm == 'inf':
            # Workaround as PyTorch doesn't have an elementwise clip
            x_adv = torch.max(torch.min(x_adv, x + eps), x - eps)
        else:
            delta = x_adv - x

            # Assume x and x_adv are batched tensors where the first dimension is
            # a batch dimension
            mask = delta.view(delta.shape[0], -1).norm(eps_norm, dim=1) <= eps

            scaling_factor = delta.view(delta.shape[0], -1).norm(eps_norm, dim=1)
            scaling_factor[mask] = eps

            # .view() assumes batched images as a 4D tensor
            delta *= eps / scaling_factor.view(-1, 1, 1, 1)

            x_adv = x + delta

        x_adv = x_adv.clamp(*clamp)

    return x_adv.detach()
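
A minimal usage sketch for reference; the model, batch, and hyperparameter values below are illustrative assumptions, not part of the gist:

import torch
import torch.nn as nn

# Placeholder MNIST-style classifier; swap in your own model
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
model.eval()

x = torch.rand(8, 1, 28, 28)        # batch of images in [0, 1]
y = torch.randint(0, 10, (8,))      # ground-truth labels

x_adv = projected_gradient_descent(
    model, x, y,
    loss_fn=nn.CrossEntropyLoss(),
    num_steps=40,
    step_size=0.01,
    step_norm='inf',
    eps=0.3,
    eps_norm='inf',
)

# The attack stays within an L-inf ball of radius eps around x and inside [0, 1]
assert (x_adv - x).abs().max() <= 0.3 + 1e-6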
@jaingeet

What is y_target here?

@ananiask8

ananiask8 commented Apr 13, 2019

I am right now doing something very similar, although my implementation is more closely based on Madry's. I get good results for normalized MNIST (mean = 0.1307, std = 0.3081). Of course, my clamp is different, with (min=-0.4242, max=2.8214). I wonder if my good results for the normalized MNIST are due to eps=0.3 needing to be rescaled. Also, I am using Adam and a different network architecture. I think it might be due to the vanishing gradient problem, since I am using sigmoid activations.

Could you tell me at how many epochs you start going over 50%?
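
On the rescaling question: if the attack runs on normalized inputs, an eps (and step size) chosen for raw [0, 1] pixels would indeed need to be divided by the normalization std. A minimal sketch, assuming MNIST normalization and that eps=0.3 was meant in raw pixel units:

# Rescale attack hyperparameters from raw-pixel units into normalized-MNIST space
mean, std = 0.1307, 0.3081

eps_raw, step_size_raw = 0.3, 0.01
eps_rescaled = eps_raw / std                 # ~0.97 in normalized units
step_size_rescaled = step_size_raw / std

# The valid data range moves too: [0, 1] maps to roughly (-0.4242, 2.8215)
clamp = ((0 - mean) / std, (1 - mean) / std)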

@ConcurrencyPractitioner

Oh, I just wanted to point one thing out: this code doesn't work if eps_norm is an integer value. The deltas in the final else branch are normalized using the variable norm, which isn't defined anywhere; I think you meant eps_norm. @oscarknagg Have you tested this extensively?
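
A quick sanity check of the L2 projection branch once the variable is renamed to eps_norm; a minimal sketch, assuming 4D image batches:

import torch

eps = 1.0
x = torch.rand(8, 1, 28, 28)
x_adv = x + 0.5 * torch.randn_like(x)    # some perturbed batch

# Re-apply the projection step with eps_norm = 2
delta = x_adv - x
norms = delta.view(delta.shape[0], -1).norm(2, dim=1)
scaling_factor = norms.clone()
scaling_factor[norms <= eps] = eps       # leave in-ball samples untouched
delta = delta * (eps / scaling_factor.view(-1, 1, 1, 1))

# Every projected perturbation should now lie inside the eps-ball
assert (delta.view(delta.shape[0], -1).norm(2, dim=1) <= eps + 1e-5).all()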

@kgautam01

x_adv = x_adv.clamp(*clamp)
Can someone explain the need for this clamping?

@LeonidStarykh

LeonidStarykh commented Jul 28, 2022

x_adv = x_adv.clamp(*clamp) Can someone explain the need for this clamping?

It's equivalent to x_adv = x_adv.clamp(0, 1): the * unpacks the clamp tuple into the min and max arguments. The clamp keeps every pixel of the adversarial example inside the valid image range.
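
A tiny self-contained illustration of the unpacking:

import torch

clamp = (0, 1)
x_adv = torch.randn(2, 1, 4, 4)    # values that fall outside [0, 1]

a = x_adv.clamp(*clamp)            # tuple unpacked into (min, max)
b = x_adv.clamp(0, 1)              # equivalent explicit call
assert torch.equal(a, b)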

@WangHexie

Does the norm here mean eps_norm? (L45)
mask = delta.view(delta.shape[0], -1).norm(norm, dim=1) <= eps
