Skip to content

Instantly share code, notes, and snippets.

@leVirve
Last active August 17, 2017 06:52
Show Gist options
  • Save leVirve/7a2bf775095a40261002e64abcf4268e to your computer and use it in GitHub Desktop.
Save leVirve/7a2bf775095a40261002e64abcf4268e to your computer and use it in GitHub Desktop.
Override `forward()` of torchvision's VGG16 so it returns intermediate feature maps, then use those maps to compute SLIC superpixels.
import types
import torch
from torch.autograd import Variable
import torchvision.models as models
import torchvision.transforms as transforms
from torchvision.datasets.folder import pil_loader
def sb_forward(self, x):
    """Replacement ``forward`` that returns early VGG feature maps.

    Runs ``x`` through the first five modules of ``self.features`` in order
    and returns the outputs taken after modules 1, 3 and 4 (conv1 + activation,
    conv2 + activation, and pool1) as a 3-tuple.
    """
    layers = self.features
    taps = (1, 3, 4)  # indices whose outputs are returned
    collected = []
    out = x
    for idx in range(5):
        out = layers[idx](out)
        if idx in taps:
            collected.append(out)
    return tuple(collected)
def extract(img=None, model=None):
    """Run an image batch through the (patched) VGG and return its outputs.

    Args:
        img: input tensor, presumably a normalized (1, 3, H, W) batch such as
            ``make_input`` produces. When ``None``, a random
            (1, 3, 512, 512) tensor is used instead.
        model: callable to run the input through. Defaults to the
            module-level ``vgg`` whose ``forward`` was replaced by
            ``sb_forward``; pass another model to reuse this helper.

    Returns:
        Whatever ``model`` returns; for the patched ``vgg`` this is the tuple
        of feature maps produced by ``sb_forward``.
    """
    if img is None:
        img = torch.randn((1, 3, 512, 512))
    if model is None:
        # Looked up at call time, so the global only has to exist by then.
        model = vgg
    return model(Variable(img, requires_grad=False))
def numpy_feature_maps(feature_maps):
    """Convert feature-map tensors to min-max scaled (H, W, C) numpy arrays.

    Each element is presumably a (1, C, H, W) tensor. The leading batch
    dimension is squeezed away, channels are moved last, and values are
    rescaled to [0, 1] using that map's own min and max.

    Args:
        feature_maps: iterable of tensors (e.g. the tuple returned by
            ``extract``).

    Returns:
        list of numpy arrays, one per input map, in the same order.
    """
    def tensor_data(x):
        x = x.squeeze(0).permute(1, 2, 0)
        # .cpu() is a no-op for CPU tensors and makes CUDA tensors convertible.
        data = x.data.cpu().numpy()
        lo = data.min()
        span = data.max() - lo
        shifted = data - lo
        if span == 0:
            # Constant map (e.g. all-zero activations): dividing would give
            # NaNs, so return the all-zero shifted array instead.
            return shifted
        return shifted / span
    return [tensor_data(x) for x in feature_maps]
def make_input(img):
    """Turn an image into a normalized single-image batch for the VGG model.

    The image goes through ``ToTensor`` and then ``Normalize`` with the
    per-channel mean/std hard-coded below (the usual ImageNet statistics
    for torchvision's pretrained models). A leading batch dimension of size
    1 is added before returning.
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    pipeline = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ])
    tensor = pipeline(img)
    return tensor.unsqueeze(0)
def load_img(path):
    """Load the image file at ``path`` using torchvision's ``pil_loader``."""
    loaded = pil_loader(path)
    return loaded
# Pretrained VGG16 whose forward() is swapped for sb_forward, so that
# calling vgg(x) returns the early feature maps instead of the normal output.
# NOTE: runs at import time and may download the pretrained weights.
vgg = models.vgg16(pretrained=True)
# This model is only used for inference: switch to eval mode so that
# train-time-only layers (e.g. Dropout) never act, even if sb_forward is
# later extended past the first few feature layers.
vgg.eval()
vgg.forward = types.MethodType(sb_forward, vgg)
import click
import numpy as np
from skimage.segmentation import slic, mark_boundaries
import matplotlib.pyplot as plt
from feature import extract, load_img, make_input, numpy_feature_maps
def _save_features(features, path='features.npz'):
    """Cache a list of (possibly differently shaped) feature arrays to ``path``."""
    # np.save on a ragged list needs a pickled object array (and raises on
    # NumPy >= 1.24); savez stores each map as its own array arr_0, arr_1, ...
    np.savez(path, *features)


def _load_features(path='features.npz'):
    """Load the arrays written by ``_save_features``, preserving their order."""
    with np.load(path) as data:
        return [data['arr_%d' % i] for i in range(len(data.files))]


@click.command()
@click.option('--path', default='imgs/2white.jpg')
@click.option('-fe', '--feature_extract', is_flag=True)
@click.option('-n', '--n_segments', default=500)
@click.option('-c', '--compactness', default=0.1)
def main(path, feature_extract, n_segments, compactness):
    """Compute SLIC superpixels on VGG feature maps of an image and show them.

    With ``--feature_extract`` the feature maps are computed from the image
    at ``--path`` and cached to disk. Otherwise the cached maps are loaded.
    The cache is not keyed by ``--path``, so it must come from the same
    image (a size mismatch is reported as an error).
    """
    image = load_img(path)
    if feature_extract:
        features = numpy_feature_maps(extract(make_input(image)))
        _save_features(features)
        print('==> Feature extracted and saved.')
    else:
        try:
            features = _load_features()
        except (IOError, OSError):
            raise click.ClickException(
                'No cached features found; rerun with --feature_extract.')
        print('==> Feature loaded from file.')
    # Only the first two maps are stacked along the channel axis; the third
    # comes after pooling and presumably has a smaller spatial size.
    stacked = np.dstack(features[:2])
    if stacked.shape[:2] != np.asarray(image).shape[:2]:
        raise click.ClickException(
            'Cached features do not match the image size; '
            'rerun with --feature_extract.')
    segments = slic(stacked, n_segments=n_segments, compactness=compactness)
    img_show(image, segments)
def img_show(image, segments):
    """Display ``image`` with the boundaries of ``segments`` drawn on it.

    Creates (or reuses) the matplotlib figure named "Superpixels", draws the
    ``mark_boundaries`` overlay on a single axes with the axis hidden, then
    calls ``plt.show()``.
    """
    overlay = mark_boundaries(image, segments)
    figure = plt.figure("Superpixels")
    axes = figure.add_subplot(1, 1, 1)
    axes.imshow(overlay)
    axes.axis("off")
    plt.show()
if __name__ == "__main__":
    # Script entry point: the click command parses the CLI arguments.
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment