Skip to content

Instantly share code, notes, and snippets.

@xuwangyin
Last active June 14, 2023 00:06
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save xuwangyin/642447139d409398f7d3bdb5e68c1cdc to your computer and use it in GitHub Desktop.
Save xuwangyin/642447139d409398f7d3bdb5e68c1cdc to your computer and use it in GitHub Desktop.
Evaluate ResNet-50 models on ImageNet using AutoAttack.
import os
import argparse
from pathlib import Path
import warnings
import torch
import torchvision
import torch.nn as nn
import torchvision.datasets as datasets
import torch.utils.data as data
import torchvision.transforms as transforms
from resnet50 import get_model
from tqdm import tqdm
import sys
sys.path.insert(0,'..')
class NormalizationWrapper(torch.nn.Module):
    """Apply ``(x - mean) / std`` to the input before calling the wrapped model.

    This lets the wrapped model be fed un-normalized inputs, so an attack can
    operate directly in that input space. ``mean`` and ``std`` are registered
    as buffers, so ``.to()`` / ``.cuda()`` move them together with the model.

    Args:
        model: module to wrap; the wrapper starts in the same train/eval mode.
        mean: per-channel means as a 1-D tensor or sequence of length C (the
            helpers in this file pass 3 values).
        std: per-channel standard deviations, same form as ``mean``.
    """

    def __init__(self, model, mean, std):
        super().__init__()
        # as_tensor returns tensor inputs unchanged and also accepts plain
        # sequences. The two trailing singleton dims turn a (C,) vector into
        # (C, 1, 1), so it broadcasts over the last three input dims --
        # presumably (N, C, H, W) images; TODO confirm for other layouts.
        mean = torch.as_tensor(mean)[..., None, None]
        std = torch.as_tensor(std)[..., None, None]
        # Called before the child is attached, so this only sets the
        # wrapper's own flag to mirror the wrapped model's mode.
        self.train(model.training)
        self.model = model
        self.register_buffer("mean", mean)
        self.register_buffer("std", std)

    def forward(self, x, *args, **kwargs):
        """Normalize *x* and pass it (plus any extra arguments) to the wrapped model."""
        x_normalized = (x - self.mean) / self.std
        return self.model(x_normalized, *args, **kwargs)

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        """Return only the wrapped model's state (no ``mean``/``std``, no ``model.`` prefix).

        A checkpoint saved from the wrapper therefore loads into the bare model.
        When a parent module (e.g. ``nn.DataParallel``) gathers its children's
        state it passes its own ``destination`` and ``prefix``. They are
        forwarded so the wrapped model's entries land in the parent's dict
        under the parent's prefix instead of being dropped. ``keep_vars`` is
        honoured as well.
        """
        return self.model.state_dict(destination=destination, prefix=prefix,
                                     keep_vars=keep_vars)
def IdentityWrapper(model):
    """Wrap *model* in a NormalizationWrapper whose normalization is a no-op.

    Uses zero mean and unit std for three channels, so ``(x - 0) / 1`` hands the
    input to *model* unchanged.
    """
    return NormalizationWrapper(model, torch.zeros(3), torch.ones(3))
def Cifar10Wrapper(model):
    """Wrap *model* so inputs are normalized with the CIFAR-10 per-channel mean/std below."""
    channel_mean = torch.tensor(
        (0.4913997551666284, 0.48215855929893703, 0.4465309133731618))
    channel_std = torch.tensor(
        (0.24703225141799082, 0.24348516474564, 0.26158783926049628))
    return NormalizationWrapper(model, mean=channel_mean, std=channel_std)
def ImageNetWrapper(model):
    """Wrap *model* so inputs are normalized with the usual ImageNet per-channel mean/std."""
    channel_mean = torch.tensor((0.485, 0.456, 0.406))
    channel_std = torch.tensor((0.229, 0.224, 0.225))
    return NormalizationWrapper(model, mean=channel_mean, std=channel_std)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--data_dir', type=str, default='/data/xuwang_yin/Projects/adversarial-corruptions/datasets/imagenet')
parser.add_argument('--norm', type=str, default='Linf')
parser.add_argument('--epsilon', type=float, default=8./255.)
parser.add_argument('--weights', type=str, default='/data/xuwang_yin/Projects/adversarial-corruptions/weights/imagenet/standard.pt')
parser.add_argument('--n_ex', type=int, default=1000)
parser.add_argument('--individual', action='store_true')
parser.add_argument('--save_dir', type=str, default='./results')
parser.add_argument('--batch_size', type=int, default=500)
parser.add_argument('--log_path', type=str, default='./log_file.txt')
parser.add_argument('--version', type=str, default='standard')
parser.add_argument('--state-path', type=Path, default=None)
args = parser.parse_args()
# load model
model = get_model(args.weights, num_classes=1000)
model = ImageNetWrapper(model)
model = nn.DataParallel(model)
model.cuda()
model.eval()
test_transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
])
test_dataset = torchvision.datasets.ImageFolder(root=os.path.join(args.data_dir, 'val'),transform=test_transform)
test_loader = data.DataLoader(test_dataset, batch_size=1000, shuffle=True, num_workers=0, generator=torch.Generator().manual_seed(123))
# create save dir
if not os.path.exists(args.save_dir):
os.makedirs(args.save_dir)
# load attack
from autoattack import AutoAttack
adversary = AutoAttack(model, norm=args.norm, eps=args.epsilon, log_path=args.log_path,
version=args.version)
x_test, y_test = [], []
for x, y in tqdm(test_loader):
x_test.append(x)
y_test.append(y)
x_test = torch.cat(x_test, dim=0)
y_test = torch.cat(y_test, dim=0)
args.n_ex = x_test.size(0)
with torch.no_grad():
out = model(x_test[:100].to('cuda'))
_, predicted = torch.max(out.to('cpu'), 1)
correct = (predicted == y_test[:100]).sum().item()
accuracy = correct / 100
print('accuracy for first 100 samples:', accuracy)
# example of custom version
if args.version == 'custom':
adversary.attacks_to_run = ['apgd-ce', 'fab']
adversary.apgd.n_restarts = 2
adversary.fab.n_restarts = 2
# run attack and save images
with torch.no_grad():
if not args.individual:
adv_complete = adversary.run_standard_evaluation(x_test[:args.n_ex], y_test[:args.n_ex],
bs=args.batch_size, state_path=args.state_path)
torch.save({'adv_complete': adv_complete}, '{}/{}_{}_1_{}_eps_{:.5f}.pth'.format(
args.save_dir, 'aa', args.version, adv_complete.shape[0], args.epsilon))
else:
# individual version, each attack is run on all test points
adv_complete = adversary.run_standard_evaluation_individual(x_test[:args.n_ex],
y_test[:args.n_ex], bs=args.batch_size)
torch.save(adv_complete, '{}/{}_{}_individual_1_{}_eps_{:.5f}_plus_{}_cheap_{}.pth'.format(
args.save_dir, 'aa', args.version, args.n_ex, args.epsilon))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment