Skip to content

Instantly share code, notes, and snippets.

@maxidl
Created February 10, 2022 11:25
Show Gist options
  • Save maxidl/4cebf7b7e2a2de62f0699aff68193e68 to your computer and use it in GitHub Desktop.
import torch
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from pathlib import Path
from tqdm.auto import tqdm
# Log CUDA availability so CPU-only runs are obvious from the output.
print(torch.cuda.is_available())
# Prefer the GPU when present; every model/tensor below is moved to `dev`.
dev = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# dev = torch.device("cpu")
### finetune on cifar-10
# Hyperparameters for the initial finetuning stage.
batch_size = 40
learning_rate = 0.001
EPOCHS = 5
WEIGHTS = Path('cifar10_weights.pt')

# Tensor conversion followed by per-channel CIFAR-10 normalization.
transform = transforms.Compose(
    [
        # transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            (0.4914, 0.4822, 0.4465),
            (0.2023, 0.1994, 0.2010),
        ),
    ]
)

# CIFAR-10 train/val splits; downloaded to ./data on first use.
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
val_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
classes = train_dataset.classes
print(f'classes: {classes}\nnumber of instances:\n\ttrain: {len(train_dataset)}\n\tval: {len(val_dataset)}')
import matplotlib.pyplot as plt
def show_examples(n):
    """Display ``n`` randomly chosen training images with their class labels.

    Reads the module-level ``train_dataset`` and ``classes`` and renders each
    image with matplotlib.
    """
    for _ in range(n):
        # randint returns a 1-element tensor; convert it to a plain int so the
        # dataset __getitem__ receives a scalar index (indexing with a tensor
        # would hand a batched array to PIL.Image.fromarray and fail).
        index = torch.randint(0, len(train_dataset), size=(1,)).item()
        image, target = train_dataset[index]
        print(f'image of shape: {image.shape}')
        print(f'label: {classes[target]}')
        # CHW -> HWC for imshow. NOTE(review): images are normalized, so the
        # colors look shifted; fine for a sanity check.
        plt.imshow(image.permute(1, 2, 0).numpy())
        plt.show()
# show_examples(4)
from torch.utils.data import DataLoader
# Shuffled loaders for both splits (shuffling the val loader is harmless
# here because only aggregate accuracy is computed from it).
train_dl = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dl = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)
import torch.nn as nn
import torch.nn.functional as F
def get_vgg_model(num_classes=10):
    """Return an ImageNet-pretrained VGG16 with a fresh ``num_classes`` head.

    The final linear layer of the classifier is replaced by a new
    ``nn.Linear`` with ``num_classes`` outputs, and the dropout layer directly
    before it is replaced by Identity (presumably so input gradients are not
    randomized — TODO confirm intent). The model is moved to the module-level
    device ``dev``.

    Args:
        num_classes: number of output classes (default 10, for CIFAR-10;
            parameterized so the same helper works for other datasets).
    """
    vgg16 = torchvision.models.vgg16(pretrained=True)
    input_last_layer = vgg16.classifier[6].in_features
    vgg16.classifier[5] = nn.Identity()
    vgg16.classifier[6] = nn.Linear(input_last_layer, num_classes)
    return vgg16.to(dev)
# Finetune on CIFAR-10 unless a checkpoint already exists; otherwise load it.
vgg16_1 = get_vgg_model()
if not WEIGHTS.exists():
    print(f'Could not find {WEIGHTS}, finetuning...')
    optimizer = optim.SGD(vgg16_1.parameters(), lr=learning_rate, momentum=0.9, weight_decay=5e-4)
    criterion = nn.CrossEntropyLoss()
    n_total_step = len(train_dl)
    vgg16_1.train()
    for epoch in range(EPOCHS):
        for i, (imgs, labels) in enumerate(tqdm(train_dl, desc=f'Training epoch {epoch+1}')):
            imgs, labels = imgs.to(dev), labels.to(dev)
            outputs = vgg16_1(imgs)
            # Per-batch correct count, for the periodic progress print below.
            n_correct = (outputs.argmax(axis=1) == labels).sum().item()
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # Zeroing after step (instead of before backward) is equivalent:
            # grads start at None on the very first batch.
            optimizer.zero_grad()
            if (i+1) % 250 == 0:
                print(f"epoch {epoch+1}/{EPOCHS}, step: {i+1}/{n_total_step}: loss = {loss:.5f}, acc = {100*(n_correct/labels.size(0)):.2f}%")
    # Validate once after finetuning, then persist the weights.
    vgg16_1.eval()
    with torch.no_grad():
        n_correct = 0
        n_samples = 0
        for i, (imgs, labels) in enumerate(tqdm(val_dl, desc='Validation')):
            imgs, labels = imgs.to(dev), labels.to(dev)
            outputs = vgg16_1(imgs)
            n_correct += (outputs.argmax(axis=1) == labels).sum().item()
            n_samples += labels.size(0)
        print(f"Validation accuracy {(n_correct / n_samples)*100}%")
    torch.save(vgg16_1.state_dict(), WEIGHTS)
else:
    print(f'Loading weights from {WEIGHTS}')
    # map_location keeps checkpoint loading working on CPU-only machines even
    # when the weights were saved from a CUDA run.
    vgg16_1.load_state_dict(torch.load(WEIGHTS, map_location=dev))
### Extension
# Number of highest-attribution pixels (per image) targeted by the
# explanation loss.
TOPK = 100
# Weight of the explanation loss relative to the cross-entropy loss.
LAMBDA = 1
# Learning rate for the manipulation (Adam) stage.
LR = 1e-5
def get_simple_gradient_expl(model, image, target, absolute=False, create_graph=True):
    """Return a simple gradient ("saliency") explanation for ``image``.

    Computes d output[:, target] / d image, optionally takes the absolute
    value, and sums the result over the channel dimension.

    Args:
        model: callable mapping an image batch to class logits.
        image: input batch of shape (1, C, H, W); its ``requires_grad`` flag
            is enabled in place.
        target: class index whose logit is differentiated.
        absolute: if True, use the absolute gradient.
        create_graph: build the graph of the derivative so the explanation can
            itself be differentiated (required when the explanation appears in
            a training loss; may be set False to save memory otherwise).

    Returns:
        An (H, W) tensor of channel-summed gradients (batch dim squeezed out).
    """
    image.requires_grad_(True)
    output = model(image)
    grad = torch.autograd.grad(output[:, target], image, create_graph=create_graph)[0]
    expl = grad.abs() if absolute else grad
    return expl.sum(1).squeeze()
# # test expl generation
# expl = get_simple_gradient_expl(vgg16_1, train_dataset[0][0].unsqueeze(0).to(dev), target=predictions[0])
# assert expl.grad_fn
class CombinedDataset(torch.utils.data.Dataset):
    """Zip several equal-length datasets.

    Item ``i`` is the list containing each constituent dataset's ``i``-th
    element, in constructor order.
    """

    def __init__(self, datasets):
        super().__init__()
        self.datasets = datasets

    def __len__(self):
        # Length comes from the first constituent; the rest are assumed to
        # match it.
        return len(self.datasets[0])

    def __getitem__(self, idx):
        items = []
        for ds in self.datasets:
            items.append(ds[idx])
        return items
# optionally, take subset of training data
dataset = torch.utils.data.Subset(train_dataset, torch.arange(0, 200))
# dataset = train_dataset
dl = DataLoader(dataset, batch_size=batch_size)  # no shuffle

# Record the model's predicted class for every image; explanations are
# computed w.r.t. these predictions, not the ground-truth labels.
vgg16_1.eval()
predictions = []
with torch.inference_mode():
    for imgs, labels in tqdm(dl, desc='getting predictions'):
        outputs = vgg16_1(imgs.to(dev))
        predictions.extend(outputs.argmax(1).tolist())

dl = DataLoader(CombinedDataset([dataset, predictions]), batch_size=batch_size)  # no shuffle

# Absolute-gradient explanations of the original model, one per image.
expls_original = []
for (imgs, labels), preds in tqdm(dl, desc='getting explanations'):
    imgs, preds = imgs.to(dev), preds.to(dev)
    expl_batch = torch.stack([
        get_simple_gradient_expl(vgg16_1, imgs[j].unsqueeze(0), preds[j], True)
        for j in range(len(imgs))
    ])
    expls_original.extend([expl.detach() for expl in expl_batch])

# Binary masks marking each image's TOPK most important pixels.
topk_masks = []
for expl in expls_original:
    topk_indices = expl.view(-1).argsort(descending=True)[:TOPK]
    # Build the mask from this explanation's own shape (the original code
    # indexed expls_original[0]; identical result since all shapes match,
    # but this is robust to heterogeneous shapes).
    topk_mask = torch.zeros_like(expl).long()
    topk_mask = topk_mask.view(-1).scatter(0, topk_indices, 1).view(expl.shape)
    topk_masks.append(topk_mask)
# Manipulation stage: finetune a copy of the model so attribution mass moves
# away from each image's originally most-important pixels, while the CE loss
# keeps the predictions intact.
EPOCHS = 10
BATCH_SIZE = 8
dl = DataLoader(CombinedDataset([dataset, predictions, expls_original, topk_masks]), batch_size=BATCH_SIZE, shuffle=True)
import copy
vgg16_2 = get_vgg_model()
vgg16_2.load_state_dict(copy.deepcopy(vgg16_1.state_dict()))
criterion = nn.CrossEntropyLoss()
optimizer2 = optim.Adam(vgg16_2.parameters(), lr=LR)
vgg16_2.train()
for epoch in range(EPOCHS):
    total_losses = []
    ce_losses = []
    expl_losses = []
    # The mask loop variable is named mask_batch (the original reused the
    # name topk_masks, clobbering the module-level list of masks).
    for (imgs, labels), preds, expls, mask_batch in tqdm(dl, desc='manipulating vgg16_2'):
        optimizer2.zero_grad()
        imgs, labels, preds = imgs.to(dev), labels.to(dev), preds.to(dev)
        mask_batch = mask_batch.to(dev)
        output = vgg16_2(imgs)
        ce_loss = criterion(output, labels)
        # Explanations must stay differentiable (create_graph) so the
        # explanation loss can be backpropagated through them.
        fooled_expls = torch.stack([
            get_simple_gradient_expl(vgg16_2, imgs[j].unsqueeze(0), preds[j], True)
            for j in range(len(imgs))
        ])
        # Attribution mass remaining on the originally-top pixels,
        # normalized by TOPK and weighted by LAMBDA.
        loss_expl = LAMBDA * ((fooled_expls * mask_batch).sum() / TOPK)
        total_loss = ce_loss + loss_expl
        total_loss.backward()
        optimizer2.step()
        total_losses.append(total_loss.item())
        ce_losses.append(ce_loss.item())
        expl_losses.append(loss_expl.item())
    # print average losses in this epoch
    print(f'Epoch {epoch}\t \ttotal:{torch.tensor(total_losses).mean().item():.3f}\tce loss:{torch.tensor(ce_losses).mean().item():.3f}\texpl loss:{torch.tensor(expl_losses).mean().item():.3f}')
# get manipulated explanations from the manipulated model, same images/preds.
dl = DataLoader(CombinedDataset([dataset, predictions]), batch_size=batch_size)  # no shuffle
expls_manipulated = []
for (imgs, labels), preds in tqdm(dl, desc='getting explanations'):
    imgs, preds = imgs.to(dev), preds.to(dev)
    expl_batch = torch.stack([
        get_simple_gradient_expl(vgg16_2, imgs[j].unsqueeze(0), preds[j], True)
        for j in range(len(imgs))
    ])
    expls_manipulated.extend([expl.detach() for expl in expl_batch])

# simple vis: raw image | original explanation | manipulated explanation
vis_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transforms.ToTensor())  # no normalization
for i in range(10):
    fig, axs = plt.subplots(1, 3)
    axs[0].imshow(vis_dataset[i][0].permute(1, 2, 0))
    axs[0].set_axis_off()
    axs[1].imshow(expls_original[i].cpu())
    axs[1].set_axis_off()
    axs[2].imshow(expls_manipulated[i].cpu())
    axs[2].set_axis_off()
    # Render each figure (the original never called show(), so nothing was
    # displayed when run as a plain script; no-op with inline backends).
    plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment