Skip to content

Instantly share code, notes, and snippets.

Last active June 7, 2019 21:59
Show Gist options
  • Save martinsotir/b51fc38e85cb728b1c187fc32c789e06 to your computer and use it in GitHub Desktop.
Save martinsotir/b51fc38e85cb728b1c187fc32c789e06 to your computer and use it in GitHub Desktop.
# Inspired by Keras and other filter-visualization examples (the original reference links were lost from this copy).
from pathlib import Path
import torch
import torchvision.utils as vutils
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
import cv2
import numpy as np
from torchvision import transforms
from PIL import Image
def total_variation_loss(x):
    """Return the squared total-variation penalty of an image batch.

    Squared differences between vertically and horizontally adjacent pixels
    are summed over the whole batch and every channel, then divided by
    ``H * W``. The batch and channel counts are NOT normalised out.

    Args:
        x: 4-D tensor, unpacked as ``(B, C, H, W)``.

    Returns:
        A 0-dim tensor, differentiable with respect to ``x``.
    """
    # NOTE: the original "From:" reference link was lost from this copy.
    _, _, H, W = x.size()
    # Squared differences along H (between rows) and along W (between columns).
    # ``flatten`` is used rather than ``reshape(B, C, -1)``: when H == 1 or
    # W == 1 the difference tensor is empty and a -1 dimension is ambiguous.
    d_rows = torch.pow(x[:, :, 1:, :] - x[:, :, :-1, :], 2).flatten(start_dim=2)
    d_cols = torch.pow(x[:, :, :, 1:] - x[:, :, :, :-1], 2).flatten(start_dim=2)
    # All terms are squares (non-negative), so their L1 norm is just their sum.
    return torch.cat([d_rows, d_cols], dim=2).sum() / (H * W)
class SaveFeatures:
    """Capture a module's forward output through a forward hook.

    After every forward pass through ``module``, its output is stored in
    ``self.features``. Call ``close()`` to remove the hook, or use the instance
    as a context manager, which closes it on exit.
    """

    def __init__(self, module):
        # Stays None until ``module`` runs a forward pass.
        self.features = None
        self.hook = module.register_forward_hook(self.hook_fn)

    def hook_fn(self, module, input, output):
        # Store the output itself, not a copy. ``torch.tensor(output, ...)``
        # would copy-construct a new leaf tensor and cut the autograd graph,
        # so a loss built from ``self.features`` could never propagate
        # gradients back to the network input. The output is already on the
        # module's device, so no explicit ``.cuda()`` move is needed.
        self.features = output

    def close(self):
        """Remove the forward hook; ``features`` keeps its last value."""
        self.hook.remove()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False
def visualize_batch(model, module, nb_filters, lr=0.1, color=True, tv_reg=0.01, l2_reg=0.05,
                    size=56, steps=10, max_pixel_value=1, device=None):
    """Optimise one input image per filter so that it strongly activates that filter.

    Image ``i`` of the batch is optimised with Adam to maximise the mean
    activation of channel ``i`` of ``module``'s output, for
    ``i in range(nb_filters)``. Two regularisers are applied:
    a total-variation penalty and a penalty on each image's normalised L2
    distance from mid-grey.

    Args:
        model: Network run on the image batch. ``module`` must be called
            during its forward pass. As a side effect, all of ``model``'s
            parameters get ``requires_grad=False``.
        module: Sub-module whose output activations are maximised.
        nb_filters: Number of filters (output channels) to visualise. This is
            also the batch size of the optimised images.
        lr: Adam learning rate.
        color: If True the images have 3 channels, otherwise 1.
        tv_reg: Weight of the total-variation penalty.
        l2_reg: Weight of the penalty on the distance from mid-grey
            (``max_pixel_value / 2``).
        size: Height and width of each image.
        steps: Number of optimisation steps.
        max_pixel_value: Initial pixels are uniform in ``[0, max_pixel_value)``.
            Also sets the mid-grey reference for the L2 penalty.
        device: Device for the images. Defaults to the device of ``model``'s
            first parameter, or CUDA if ``model`` has no parameters.

    Returns:
        A detached tensor of shape ``(nb_filters, 3 if color else 1, size, size)``.

    Raises:
        RuntimeError: If ``module`` produced no output during the forward pass.
    """
    # Freeze the network so only the images are optimised. This replaces a
    # ``set_trainable(model, False)`` helper that is not defined in this file.
    for param in model.parameters():
        param.requires_grad_(False)

    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cuda')

    half = max_pixel_value / 2
    activations = SaveFeatures(module)
    try:
        img = (torch.rand((nb_filters, 3 if color else 1, size, size), dtype=torch.float32,
                          device=device) * max_pixel_value).requires_grad_()
        optimizer = torch.optim.Adam([img], lr=lr, weight_decay=0)
        for _ in range(steps):
            optimizer.zero_grad()
            model(img)  # the forward hook stores module's output in activations.features
            feats = getattr(activations, 'features', None)
            if feats is None:
                raise RuntimeError('module was not called during the forward pass of model')
            idx = torch.arange(nb_filters, device=feats.device)
            # Mean activation of channel i for image i, one value per image.
            act = feats[idx, idx].reshape(nb_filters, -1).mean(dim=1)
            # Normalised L2 distance of each image from mid-grey. It is added
            # as a penalty, outside the negated activation term; inside the
            # negation it would be maximised instead of regularised.
            l2 = ((img - half) / half).reshape(nb_filters, -1).norm(p=2, dim=1)
            loss = (-act + l2_reg * l2).mean() + tv_reg * total_variation_loss(img)
            loss.backward()
            optimizer.step()
    finally:
        # Always detach the hook, even if the optimisation fails.
        activations.close()
    return img.detach()
def plot_filter_ma(net, layer_name, layer_conv_name, lr=0.01, steps=100, color=True, tv_reg=0.01,
                   l2_reg=0.01, size=80, max_pixel_value=1, root_name='CCN'):
    """Visualise every filter of one layer of ``net`` as a grid of optimised images.

    The target layer is found by walking ``named_children()`` from ``net``
    through ``root_name``, ``layer_name`` and ``layer_conv_name``. The number
    of filters is read from that layer's ``conv`` child, the images are
    optimised with ``visualize_batch``, and the result is drawn with
    matplotlib in a new 15x15-inch figure.

    Args:
        net: Model to inspect.
        layer_name: Name of the child of ``net.<root_name>`` holding the layer.
        layer_conv_name: Name of the target layer inside ``layer_name``. The
            target layer must have a child named ``'conv'`` with an
            ``out_channels`` attribute.
        lr, steps, tv_reg, l2_reg, size, max_pixel_value: Forwarded to
            ``visualize_batch``.
        color: Optimise and display 3-channel images if True, otherwise
            single-channel images shown in greyscale.
        root_name: Name of the top-level child of ``net`` to start from.

    Returns:
        The detached image batch returned by ``visualize_batch``.
    """
    # NOTE(review): the original lookup chain was truncated in this copy. It is
    # reconstructed as net -> root_name -> layer_name -> layer_conv_name using
    # nested named_children() lookups; confirm against the real model.
    layer_module = net
    for child_name in (root_name, layer_name, layer_conv_name):
        layer_module = dict(layer_module.named_children())[child_name]
    nb_filters = dict(layer_module.named_children())['conv'].out_channels

    img = visualize_batch(net, layer_module, nb_filters, lr=lr, steps=steps, color=color,
                          tv_reg=tv_reg, l2_reg=l2_reg, size=size,
                          max_pixel_value=max_pixel_value)

    # make_grid returns a (C, H, W) grid (single-channel input is replicated
    # to 3 channels); matplotlib expects channels last.
    grid = np.moveaxis(make_grid(img, normalize=True).cpu().numpy(), 0, 2)
    plt.figure(figsize=(15, 15))
    if color:
        # Reverse the channel order (presumably BGR -> RGB; confirm the model's input order).
        plt.imshow(grid[:, :, ::-1])
    else:
        plt.imshow(grid[:, :, 0], cmap='gray')
    return img
if __name__ == "__main__":
    # Example usage. Replace the ``...`` placeholder with the trained model to
    # inspect; the guard keeps this from running when the file is imported.
    net = ...
    if net is ...:
        raise SystemExit("Set `net` to the model to visualise before running this script.")
    net = net.cuda().eval()
    color = False
    plot_filter_ma(net, 'base_layer', '0', lr=0.1, steps=300, color=color, tv_reg=0.0001,
                   l2_reg=0.01, size=80, max_pixel_value=1)
    # A plain script run does not display figures without an explicit show().
    plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment