Skip to content

Instantly share code, notes, and snippets.

@xuwangyin
Created August 9, 2023 23:24
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save xuwangyin/aba5faae16011e168de96abb17b54017 to your computer and use it in GitHub Desktop.
Save xuwangyin/aba5faae16011e168de96abb17b54017 to your computer and use it in GitHub Desktop.
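Evaluate a robust CIFAR-10 WideResNet-28-10 model with AutoAttack. Place this script in the root of https://github.com/locuslab/robust_overfitting (it imports `wideresnet.py` from there) and clone the auto-attack repository into `./auto-attack`.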
import os
import argparse
from pathlib import Path
import warnings
import torch
import torch.nn as nn
import torchvision.datasets as datasets
import torch.utils.data as data
import torchvision.transforms as transforms
from wideresnet import WideResNet
import sys
sys.path.insert(0,'./auto-attack')
class NormalizationWrapper(torch.nn.Module):
    """Wrap a model so it accepts un-normalized inputs.

    Before calling the wrapped model, inputs are standardized as
    ``(x - mean) / std``. ``mean`` and ``std`` are registered as buffers, so
    they move with the wrapper on ``.cuda()`` / ``.to()``.

    Args:
        model: the module to wrap.
        mean: per-channel mean. Two trailing singleton dimensions are appended
            so it broadcasts over the last two dimensions of the input
            (presumably a 1-D tensor of length C for (N, C, H, W) images --
            TODO confirm with callers).
        std: per-channel standard deviation, reshaped the same way as ``mean``.
    """

    def __init__(self, model, mean, std):
        super().__init__()
        mean = mean[..., None, None]
        std = std[..., None, None]
        # Mirror the wrapped model's train/eval mode. No children are
        # registered yet, so this only sets ``self.training``.
        self.train(model.training)
        self.model = model
        self.register_buffer("mean", mean)
        self.register_buffer("std", std)

    def forward(self, x, *args, **kwargs):
        """Normalize ``x`` and pass it (plus any extra args) to the model."""
        x_normalized = (x - self.mean) / self.std
        return self.model(x_normalized, *args, **kwargs)

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        """Return the wrapped model's state dict, omitting ``mean``/``std``.

        The keys are those of the inner model, without a ``model.`` prefix.
        The arguments are forwarded so that the wrapper also works as a child
        module: a parent's ``state_dict`` (e.g. ``nn.DataParallel`` or
        ``nn.Sequential``) recurses with a shared ``destination`` and a
        ``prefix``. Ignoring them would drop every wrapped parameter from the
        parent's state dict.
        """
        return self.model.state_dict(
            destination=destination, prefix=prefix, keep_vars=keep_vars
        )
def IdentityWrapper(model):
    """Wrap ``model`` in a NormalizationWrapper whose normalization is a no-op.

    Uses a mean of zero and a standard deviation of one for each of three
    channels, so ``(x - mean) / std`` returns ``x``.
    """
    zero_mean, unit_std = torch.zeros(3), torch.ones(3)
    return NormalizationWrapper(model, mean=zero_mean, std=unit_std)
def Cifar10Wrapper(model):
    """Wrap ``model`` so it normalizes inputs with fixed per-channel statistics.

    The hard-coded values are presumably the CIFAR-10 training-set RGB mean
    and standard deviation (confirm against the training pipeline).
    """
    channel_mean = [0.4913997551666284, 0.48215855929893703, 0.4465309133731618]
    channel_std = [0.24703225141799082, 0.24348516474564, 0.26158783926049628]
    return NormalizationWrapper(
        model, torch.tensor(channel_mean), torch.tensor(channel_std)
    )
if __name__ == '__main__':
    # Evaluate a WideResNet checkpoint with AutoAttack on the CIFAR-10 test
    # split and save the resulting adversarial examples under --save_dir.
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_dir', type=str, default='./data')
    parser.add_argument('--norm', type=str, default='Linf')
    parser.add_argument('--epsilon', type=float, default=8./255.)
    parser.add_argument('--model', type=str, required=True)
    parser.add_argument('--n_ex', type=int, default=1000)
    parser.add_argument('--individual', action='store_true')
    parser.add_argument('--save_dir', type=str, default='./results')
    parser.add_argument('--batch_size', type=int, default=500)
    parser.add_argument('--log_path', type=str, default='./log_file.txt')
    parser.add_argument('--version', type=str, default='standard')
    parser.add_argument('--state-path', type=Path, default=None)
    args = parser.parse_args()

    # load model
    model = WideResNet(28, 10, widen_factor=10, dropRate=0.0)
    # NOTE(review): torch.load unpickles the file -- only load trusted
    # checkpoints. map_location='cpu' lets checkpoints saved on any device
    # load here; the model is moved to the GPU below.
    ckpt = torch.load(args.model, map_location='cpu')['state_dict']
    # Drop 'module.' from parameter names (presumably added because the model
    # was saved while wrapped in nn.DataParallel). Note: str.replace removes
    # every occurrence, not only a leading one.
    ckpt = {k.replace('module.', ''): v for k, v in ckpt.items()}
    model.load_state_dict(ckpt)
    model = Cifar10Wrapper(model)
    model = nn.DataParallel(model)
    model.cuda()
    model.eval()

    # load data
    transform_chain = transforms.Compose([transforms.ToTensor()])
    test_set = datasets.CIFAR10(root=args.data_dir, train=False,
                                transform=transform_chain, download=True)
    test_loader = data.DataLoader(test_set, batch_size=1000, shuffle=False,
                                  num_workers=0)

    # create save dir
    os.makedirs(args.save_dir, exist_ok=True)

    # load attack (importable because './auto-attack' was put on sys.path)
    from autoattack import AutoAttack
    adversary = AutoAttack(model, norm=args.norm, eps=args.epsilon,
                           log_path=args.log_path, version=args.version)

    # Collect the test set in a single pass over the loader.
    batches = list(test_loader)
    x_test = torch.cat([x for x, _ in batches], 0)
    y_test = torch.cat([y for _, y in batches], 0)
    x_eval, y_eval = x_test[:args.n_ex], y_test[:args.n_ex]

    # example of custom version
    if args.version == 'custom':
        adversary.attacks_to_run = ['apgd-ce', 'fab']
        adversary.apgd.n_restarts = 2
        adversary.fab.n_restarts = 2

    # run attack and save images
    # NOTE(review): mirrors the upstream auto-attack example; the attacks
    # presumably re-enable gradients internally -- confirm.
    with torch.no_grad():
        if not args.individual:
            adv_complete = adversary.run_standard_evaluation(
                x_eval, y_eval, bs=args.batch_size, state_path=args.state_path)
            torch.save({'adv_complete': adv_complete},
                       '{}/{}_{}_1_{}_eps_{:.5f}.pth'.format(
                           args.save_dir, 'aa', args.version,
                           adv_complete.shape[0], args.epsilon))
        else:
            # individual version, each attack is run on all test points
            adv_complete = adversary.run_standard_evaluation_individual(
                x_eval, y_eval, bs=args.batch_size)
            # One placeholder per argument; the example count is the number
            # actually evaluated (may be less than --n_ex).
            torch.save(adv_complete,
                       '{}/{}_{}_individual_1_{}_eps_{:.5f}.pth'.format(
                           args.save_dir, 'aa', args.version,
                           x_eval.shape[0], args.epsilon))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment