Skip to content

Instantly share code, notes, and snippets.

@soravux
Last active September 26, 2017 04:03
Show Gist options
  • Save soravux/d6622e2ca367d5dc5d1eab3de2cf3f0e to your computer and use it in GitHub Desktop.
Save soravux/d6622e2ca367d5dc5d1eab3de2cf3f0e to your computer and use it in GitHub Desktop.
Test pour la stabilité des exemples adversariaux.
from __future__ import print_function
import os
from tqdm import tqdm
import numpy as np
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from matplotlib import pyplot as plt
# Training settings.
# Every help string uses "%(default)s" so argparse substitutes the actual
# default value; the text can therefore never disagree with `default=`.
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                    help='input batch size for training (default: %(default)s)')
parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                    help='input batch size for testing (default: %(default)s)')
parser.add_argument('--epochs', type=int, default=2, metavar='N',
                    help='number of epochs to train (default: %(default)s)')
parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                    help='learning rate (default: %(default)s)')
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                    help='how many batches to wait before logging training status')
# Parsed at import time: the rest of the script reads these as module globals.
args = parser.parse_args()
# Preprocessing shared by both loaders: ToTensor, then Normalize computes
# (x - 0.1307) / 0.3081 per pixel (presumably the MNIST mean/std).
_mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./data', train=True, download=True,
                   transform=_mnist_transform),
    batch_size=args.batch_size, shuffle=True)

# Evaluation honours --test-batch-size. test() sums per-sample losses
# (size_average=False), so its results do not depend on this batch size.
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./data', train=False, transform=_mnist_transform),
    batch_size=args.test_batch_size, shuffle=True)
class Net(nn.Module):
    """Small CNN classifier returning log-probabilities over 10 classes.

    Layout: conv(1->10, 5x5) -> max-pool 2 -> ReLU -> conv(10->20, 5x5)
    -> channel dropout -> max-pool 2 -> ReLU -> flatten -> fc(320->50)
    -> ReLU -> dropout -> fc(50->10) -> log-softmax.

    The 320 flattened features equal 20 * 4 * 4, which is what a 1x28x28
    input produces after the two conv/pool stages.
    """

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        """Map a (batch, 1, 28, 28) batch to (batch, 10) log-probabilities."""
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        # Flatten per sample. Keeping the batch dimension explicit makes a
        # wrongly sized input fail in fc1 instead of being silently reshaped
        # into extra "samples".
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        # Normalise over the class axis explicitly (the implicit dim is
        # deprecated; `dim` requires PyTorch >= 0.3).
        return F.log_softmax(x, dim=1)
# Module-level network shared by training, evaluation and the adversarial
# search, with an Adam optimizer over all of its parameters.
model = Net()
optimizer = optim.Adam(params=model.parameters(), lr=args.lr)
def train(epoch):
    """Run one training epoch of the global ``model`` over ``train_loader``.

    Updates the weights with the global ``optimizer`` and prints the current
    batch loss every ``args.log_interval`` batches.

    Args:
        epoch: epoch number, used only in the progress messages.
    """
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        # The inputs are constants for weight training: no gradient with
        # respect to the pixels is needed, so autograd is not asked for one.
        data, target = Variable(data), Variable(target)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            # `loss.data.sum()` yields the scalar loss for both the legacy
            # 1-element tensor and the modern 0-dim tensor; `loss.data[0]`
            # raises on 0-dim tensors (PyTorch >= 0.4).
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), float(loss.data.sum())))
def getAdversarials(data, target_idx, n_steps=1000, step_size=1.0):
    """Perturb a batch of inputs so the global ``model`` predicts ``target_idx``.

    Performs ``n_steps`` of gradient descent on the input pixels, minimising
    the NLL loss of ``model`` toward class ``target_idx``. The weights are not
    stepped: the global ``optimizer`` is only used to zero their gradients so
    they do not accumulate across iterations.

    Args:
        data: Variable holding a batch of images; dimension 0 is the batch.
        target_idx: class index every sample should be pushed toward.
        n_steps: number of gradient steps (default 1000).
        step_size: factor applied to the input gradient at each step
            (default 1.0, i.e. an unscaled gradient step).

    Returns:
        The perturbed batch, with the same shape as ``data``.
    """
    # Size the target from the actual batch rather than a fixed 64, so any
    # --batch-size (or a short final batch) works.
    batch_size = data.size(0)
    target = Variable(torch.LongTensor(batch_size).fill_(target_idx))
    adv_data = Variable(data.data, requires_grad=True)
    for _ in tqdm(range(n_steps)):
        optimizer.zero_grad()
        # Re-wrap as a fresh leaf each step so .grad holds only this step's
        # gradient and the previous step's graph is dropped.
        adv_data = Variable(adv_data.data, requires_grad=True)
        output = model(adv_data)
        loss = F.nll_loss(output, target)
        loss.backward()
        adv_data = adv_data - step_size * adv_data.grad
    # Score the batch that is actually returned; the loop's last `output`
    # was computed before the final update.
    output = model(Variable(adv_data.data))
    _, idx = torch.max(output, 1)
    print("Nombre de la batch qui sont maintenant des {}: {}".format(target_idx, np.sum(idx.data.numpy() == target_idx)))
    return adv_data
def checkNoiseSensitivity(data, targets, sigma_noise=0.25):
    """Compare the global ``model``'s predictions with and without Gaussian noise.

    Adds ``sigma_noise * N(0, 1)`` noise to every input value. The function
    then:

    * prints the accuracy on the clean and on the noisy batch;
    * opens a figure showing the first two samples, clean vs. noisy;
    * opens a figure with histograms of the per-sample maximum
      log-probability, clean vs. noisy.

    Both figures are shown non-blocking.

    Args:
        data: Variable holding a batch indexed ``[sample, channel, ...]``;
            channel 0 of samples 0 and 1 is plotted.
        targets: a single int (every sample is expected to be that class) or
            a tensor of per-sample labels.
        sigma_noise: standard deviation of the additive noise.
    """
    output = model(data)
    maxval, idx = torch.max(output, 1)
    # Add a bit of noise.
    data_bruite = Variable(data.data + sigma_noise * torch.randn(data.size()))
    output_bruite = model(data_bruite)
    maxval_bruite, idx_bruite = torch.max(output_bruite, 1)

    clean_np = data.data.numpy()
    noisy_np = data_bruite.data.numpy()
    plt.figure()
    for sample in (0, 1):
        # Left column: clean sample; right column: its noisy version.
        plt.subplot(2, 2, 2 * sample + 1)
        plt.imshow(clean_np[sample, 0, ...])
        plt.colorbar()
        plt.subplot(2, 2, 2 * sample + 2)
        plt.imshow(noisy_np[sample, 0, ...])
        plt.colorbar()

    if isinstance(targets, int):
        targets = np.full(clean_np.shape[0], targets)
    else:
        targets = targets.numpy()
    maxval = maxval.data.numpy()
    maxval_bruite = maxval_bruite.data.numpy()
    idx = idx.data.numpy()
    idx_bruite = idx_bruite.data.numpy()

    # np.mean of the boolean match array is a true fraction. The former
    # `np.sum(...) / size` floor-divides under Python 2, reporting only 0% or 100%.
    acc_normal = np.mean(idx == targets)
    acc_bruite = np.mean(idx_bruite == targets)
    print("acc. normal: {:.2f}%".format(acc_normal*100))
    print("acc. bruite: {:.2f}%".format(acc_bruite*100))

    plt.figure()
    plt.subplot(211)
    plt.hist(maxval, 50)
    plt.subplot(212)
    plt.hist(maxval_bruite, 50)
    plt.show(block=False)
def test():
    """Evaluate the global ``model`` on ``test_loader`` and print loss/accuracy.

    The printed loss is the per-sample average NLL over the whole test set.
    The accuracy is the count and percentage of samples whose arg-max class
    matches the label.
    """
    model.eval()
    test_loss = 0
    correct = 0
    for data, target in test_loader:
        # NOTE(review): `volatile` only disables autograd on PyTorch < 0.4;
        # newer versions ignore it with a warning (torch.no_grad() replaces it).
        data, target = Variable(data, volatile=True), Variable(target)
        output = model(data)
        # Sum (not mean) of the batch losses; normalised by dataset size below.
        # `.data.sum()` works on both legacy 1-element and modern 0-dim
        # tensors, whereas `.data[0]` raises on 0-dim tensors.
        test_loss += float(F.nll_loss(output, target, size_average=False).data.sum())
        pred = output.data.max(1)[1]  # index of the max log-probability
        # int() keeps `correct` a plain number on versions where sum()
        # returns a tensor.
        correct += int(pred.eq(target.data.view_as(pred)).cpu().sum())

    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
if __name__ == "__main__":
    # Weights are cached on disk after the first training run. The file name
    # is kept as-is so previously saved checkpoints are still found.
    model_filename = "ex2model.mesouviensplusdelextension"
    if os.path.isfile(model_filename):
        model.load_state_dict(torch.load(model_filename))
    else:
        for epoch in range(1, args.epochs + 1):
            train(epoch)
        torch.save(model.state_dict(), model_filename)

    # The experiments below are inference, so dropout must be off. At this
    # point the model is still in training mode: train() calls model.train(),
    # and a newly built nn.Module also starts in training mode. Without this
    # call every prediction would be stochastic.
    model.eval()

    # Run the experiments on the first (shuffled) training batch.
    data, target = next(iter(train_loader))
    data = Variable(data)
    adversarials = getAdversarials(data, 1)  # we want them misclassified as "1"
    print("Données standard")
    checkNoiseSensitivity(data, target)
    print("Données adversariales")
    checkNoiseSensitivity(adversarials, 1)

    # First three samples: clean (left column) vs. adversarial (right column).
    plt.figure()
    for row in range(3):
        plt.subplot(3, 2, 2 * row + 1)
        plt.imshow(data.data.numpy()[row, 0, ...])
        plt.colorbar()
        plt.subplot(3, 2, 2 * row + 2)
        plt.imshow(adversarials.data.numpy()[row, 0, ...])
        plt.colorbar()
    plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment