Skip to content

Instantly share code, notes, and snippets.

@crcrpar
Created October 6, 2017 11:11
Show Gist options
  • Save crcrpar/6def91607b39aa7e491e17e6e84d95ec to your computer and use it in GitHub Desktop.
Save crcrpar/6def91607b39aa7e491e17e6e84d95ec to your computer and use it in GitHub Desktop.
Wrapper for the pretrained VGG19 model in PyTorch that lets you extract any hidden feature.
import torch
import torch.nn as nn
from torchvision.models import vgg19
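"""
Layer layout of torchvision's ``vgg19().features``; the indices match the
keys of ``_vgg19_table`` below.
sequential (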
# block_1
(0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU (inplace)
(2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(3): ReLU (inplace)
(4): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1))
# block_2
(5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(6): ReLU (inplace)
(7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(8): ReLU (inplace)
(9): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1))
# block_3
(10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(11): ReLU (inplace)
(12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(13): ReLU (inplace)
(14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(15): ReLU (inplace)
(16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(17): ReLU (inplace)
(18): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1))
# block_4
(19): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(20): ReLU (inplace)
(21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(22): ReLU (inplace)
(23): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(24): ReLU (inplace)
(25): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(26): ReLU (inplace)
(27): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1))
# block_5
(28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(29): ReLU (inplace)
(30): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(31): ReLU (inplace)
(32): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(33): ReLU (inplace)
(34): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(35): ReLU (inplace)
(36): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1))
)
"""
class Vgg19(torch.nn.Module):
    """Truncated VGG19 feature extractor split into four sequential slices.

    The slices hold consecutive layers of ``vgg19(...).features``:

    * ``slice1``: indices 0-3   (conv_1_1 ... relu_1_2)
    * ``slice2``: indices 4-8   (max_pool_1 ... relu_2_2)
    * ``slice3``: indices 9-15  (max_pool_2 ... relu_3_3)
    * ``slice4``: indices 16-22 (conv_3_4 ... relu_4_2)

    Layers after index 22 are not kept, so only ReLU outputs up to and
    including ``relu_4_2`` can be extracted.

    Args:
        feature: Which ReLU inside a block ``get_relu`` returns; the
            requested layer is ``relu_<block_number>_<feature>``.
        pretrained: Forwarded to ``torchvision.models.vgg19``.
        requires_grad: When False, every parameter is frozen.
    """

    def __init__(self, feature=1, pretrained=True, requires_grad=False):
        super(Vgg19, self).__init__()
        vgg_pretrained_features = vgg19(pretrained=pretrained).features
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        # Each module keeps its absolute index in ``features`` as its name,
        # so state_dict keys look like ``slice2.5.weight``.
        for x in range(4):
            self.slice1.add_module(str(x), vgg_pretrained_features[x])
        for x in range(4, 9):
            self.slice2.add_module(str(x), vgg_pretrained_features[x])
        for x in range(9, 16):
            self.slice3.add_module(str(x), vgg_pretrained_features[x])
        for x in range(16, 23):
            self.slice4.add_module(str(x), vgg_pretrained_features[x])
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False
        self.feature = feature

    def _iter_layers(self):
        """Yield the kept VGG layers in forward order."""
        for part in (self.slice1, self.slice2, self.slice3, self.slice4):
            for layer in part:
                yield layer

    def get_relu(self, x, block_number):
        """Return the output of ``relu_<block_number>_<self.feature>``.

        ``x`` is run through the kept layers in order, stopping at the
        requested ReLU. Layers are identified structurally: a ``Conv2d``
        advances the conv counter inside the current block, and a
        ``MaxPool2d`` starts the next block.

        Raises:
            ValueError: if the requested ReLU is not among the kept layers
                (for example anything in block 5, or ``relu_4_3``).
        """
        layer_name = 'relu_{}_{}'.format(block_number, self.feature)
        block, conv = 1, 0
        for layer in self._iter_layers():
            x = layer(x)
            if isinstance(layer, torch.nn.Conv2d):
                conv += 1
            elif isinstance(layer, torch.nn.ReLU):
                if 'relu_{}_{}'.format(block, conv) == layer_name:
                    return x
            elif isinstance(layer, torch.nn.MaxPool2d):
                block += 1
                conv = 0
                if block > block_number:
                    # Already past the requested block; nothing can match.
                    break
        raise ValueError(
            '{} is not available; this extractor only keeps layers up to '
            'relu_4_2'.format(layer_name))
# Maps the index of each layer in ``vgg19().features`` (as a string) to a
# readable name: conv_<block>_<n>, relu_<block>_<n> or max_pool_<block>.
_vgg19_table = {
    '0': 'conv_1_1',
    '1': 'relu_1_1',
    '2': 'conv_1_2',
    '3': 'relu_1_2',
    '4': 'max_pool_1',
    '5': 'conv_2_1',
    '6': 'relu_2_1',
    '7': 'conv_2_2',
    '8': 'relu_2_2',
    '9': 'max_pool_2',
    '10': 'conv_3_1',
    '11': 'relu_3_1',
    '12': 'conv_3_2',
    '13': 'relu_3_2',
    '14': 'conv_3_3',
    '15': 'relu_3_3',
    '16': 'conv_3_4',
    '17': 'relu_3_4',
    '18': 'max_pool_3',
    '19': 'conv_4_1',
    '20': 'relu_4_1',
    '21': 'conv_4_2',
    '22': 'relu_4_2',
    '23': 'conv_4_3',
    '24': 'relu_4_3',
    '25': 'conv_4_4',
    '26': 'relu_4_4',
    '27': 'max_pool_4',
    '28': 'conv_5_1',
    '29': 'relu_5_1',
    '30': 'conv_5_2',
    '31': 'relu_5_2',
    '32': 'conv_5_3',
    '33': 'relu_5_3',
    '34': 'conv_5_4',
    '35': 'relu_5_4',
    '36': 'max_pool_5',
}

# Inverse of ``_vgg19_table``: layer name -> string index.
_name_index_table = {name: index for index, name in _vgg19_table.items()}
# Original (misspelled) name, kept as an alias so existing references work.
_name_indxe_table = _name_index_table

# Names of the 16 convolution layers, in forward order (derived from the
# table so the two cannot drift apart).
_conv_list = [name for name in _vgg19_table.values()
              if name.startswith('conv_')]
class VGG19(nn.Module):
    """Full VGG19 convolutional trunk with every layer addressable by name.

    Each of the 37 modules of ``vgg19(...).features`` is registered as a
    child named after ``_vgg19_table`` (``conv_1_1``, ``relu_1_1``, ...,
    ``max_pool_5``), so pretrained weights appear under those names.

    Args:
        pretrained: Forwarded to ``torchvision.models.vgg19``.
        requires_grad: When False (the default), all parameters are frozen.
    """

    def __init__(self, pretrained=True, requires_grad=False):
        super(VGG19, self).__init__()
        no_top_vgg19 = vgg19(pretrained=pretrained).features
        # Parameter-free helpers kept so the attributes remain available;
        # get_feature uses the registered per-layer modules instead.
        self.relu = nn.ReLU(inplace=True)
        self.mp = nn.MaxPool2d(kernel_size=2, stride=2, dilation=1)
        # Forward order of the named layers, walked by get_feature.
        self._layer_names = []
        for i in range(len(_vgg19_table)):
            name = _vgg19_table[str(i)]
            self.add_module(name, no_top_vgg19[i])
            self._layer_names.append(name)
        # Freeze the weights unless gradients were explicitly requested.
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def get_feature(self, x, layer_name):
        """Return the activation of ``layer_name`` for input ``x``.

        ``x`` is passed through the named layers in forward order. The
        output of ``layer_name`` is returned as soon as it is computed, so
        later layers are never evaluated. Any registered layer name is
        accepted, including ``max_pool_<n>``.

        Raises:
            ValueError: if ``layer_name`` is not a registered layer name.
        """
        if layer_name not in self._layer_names:
            raise ValueError(
                'unknown VGG19 layer name: {!r}'.format(layer_name))
        for name in self._layer_names:
            x = getattr(self, name)(x)
            if name == layer_name:
                return x
if __name__ == '__main__':
    # Smoke test with random weights: print the shape of the first ReLU
    # output of each of the five blocks.
    test_vgg = VGG19(pretrained=False)
    print('\n')
    # Plain tensors carry autograd state directly; the deprecated
    # torch.autograd.Variable wrapper is not needed.
    x = torch.randn(10, 3, 224, 224)
    # Inference only, so don't record an autograd graph.
    with torch.no_grad():
        for i in range(1, 6):
            layer = 'relu_{}_1'.format(i)
            out = test_vgg.get_feature(x, layer_name=layer)
            print(layer, out.size())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment