Created October 6, 2017
Wrapper for the pretrained VGG19 in PyTorch; lets you extract any hidden feature.
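The pretrained weights expect ImageNet-normalized input. A minimal preprocessing sketch (the mean/std values are the standard ImageNet statistics; 'cat.jpg' is a placeholder path):

import torchvision.transforms as T
from PIL import Image

preprocess = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
img = preprocess(Image.open('cat.jpg')).unsqueeze(0)  # batch of one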
import torch
import torch.nn as nn
from torchvision.models import vgg19

"""
Sequential (
  # block_1
  (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): ReLU (inplace)
  (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (3): ReLU (inplace)
  (4): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1))
  # block_2
  (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (6): ReLU (inplace)
  (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (8): ReLU (inplace)
  (9): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1))
  # block_3
  (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (11): ReLU (inplace)
  (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (13): ReLU (inplace)
  (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (15): ReLU (inplace)
  (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (17): ReLU (inplace)
  (18): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1))
  # block_4
  (19): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (20): ReLU (inplace)
  (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (22): ReLU (inplace)
  (23): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (24): ReLU (inplace)
  (25): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (26): ReLU (inplace)
  (27): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1))
  # block_5
  (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (29): ReLU (inplace)
  (30): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (31): ReLU (inplace)
  (32): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (33): ReLU (inplace)
  (34): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (35): ReLU (inplace)
  (36): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1))
)
"""
class Vgg19(torch.nn.Module):
    """VGG19 split into four feature slices."""

    def __init__(self, pretrained=True, requires_grad=False):
        super(Vgg19, self).__init__()
        # note: torchvision >= 0.13 replaces ``pretrained`` with ``weights``
        vgg_pretrained_features = vgg19(pretrained=pretrained).features
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        for x in range(4):  # layers 0-3, up to relu_1_2
            self.slice1.add_module(str(x), vgg_pretrained_features[x])
        for x in range(4, 9):  # layers 4-8, up to relu_2_2
            self.slice2.add_module(str(x), vgg_pretrained_features[x])
        for x in range(9, 16):  # layers 9-15, up to relu_3_3
            self.slice3.add_module(str(x), vgg_pretrained_features[x])
        for x in range(16, 23):  # layers 16-22, up to relu_4_2
            self.slice4.add_module(str(x), vgg_pretrained_features[x])
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def get_relu(self, x, block_number):
        """Return the ReLU activation after slice<block_number> (1-4)."""
        # run every slice up to and including the requested one
        for i in range(1, block_number + 1):
            x = getattr(self, 'slice{}'.format(i))(x)
        return x
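# Usage sketch for Vgg19 (``img`` stands in for an ImageNet-normalized
# batch such as the one built in the preprocessing example above):
#
#     net = Vgg19(pretrained=True)
#     feat = net.get_relu(img, block_number=3)  # activation of relu_3_3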
_vgg19_table = {
    '0': 'conv_1_1',
    '1': 'relu_1_1',
    '2': 'conv_1_2',
    '3': 'relu_1_2',
    '4': 'max_pool_1',
    '5': 'conv_2_1',
    '6': 'relu_2_1',
    '7': 'conv_2_2',
    '8': 'relu_2_2',
    '9': 'max_pool_2',
    '10': 'conv_3_1',
    '11': 'relu_3_1',
    '12': 'conv_3_2',
    '13': 'relu_3_2',
    '14': 'conv_3_3',
    '15': 'relu_3_3',
    '16': 'conv_3_4',
    '17': 'relu_3_4',
    '18': 'max_pool_3',
    '19': 'conv_4_1',
    '20': 'relu_4_1',
    '21': 'conv_4_2',
    '22': 'relu_4_2',
    '23': 'conv_4_3',
    '24': 'relu_4_3',
    '25': 'conv_4_4',
    '26': 'relu_4_4',
    '27': 'max_pool_4',
    '28': 'conv_5_1',
    '29': 'relu_5_1',
    '30': 'conv_5_2',
    '31': 'relu_5_2',
    '32': 'conv_5_3',
    '33': 'relu_5_3',
    '34': 'conv_5_4',
    '35': 'relu_5_4',
    '36': 'max_pool_5',
}
# inverse mapping: layer name -> index in vgg19().features
_name_index_table = {}
for key, value in _vgg19_table.items():
    _name_index_table[value] = key

_conv_list = ['conv_1_1', 'conv_1_2', 'conv_2_1', 'conv_2_2', 'conv_3_1',
              'conv_3_2', 'conv_3_3', 'conv_3_4', 'conv_4_1', 'conv_4_2',
              'conv_4_3', 'conv_4_4', 'conv_5_1', 'conv_5_2', 'conv_5_3',
              'conv_5_4']
class VGG19(nn.Module):
    """VGG19 whose layers are registered under human-readable names."""

    def __init__(self, pretrained=True, requires_grad=False):
        super(VGG19, self).__init__()
        no_top_vgg19 = vgg19(pretrained=pretrained).features
        self.relu = nn.ReLU(inplace=True)
        self.mp = nn.MaxPool2d(kernel_size=2, stride=2, dilation=1)
        for i in range(37):
            self.add_module(_vgg19_table[str(i)], no_top_vgg19[i])
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def get_feature(self, x, layer_name):
        """Run forward until ``layer_name`` and return its activation."""
        for idx, name in enumerate(_conv_list):
            # a max-pool precedes the first conv of blocks 2-5
            if idx in (2, 4, 8, 12):
                x = self.mp(x)
            x = getattr(self, name)(x)
            if layer_name == name:
                return x
            x = self.relu(x)
            if name.replace('conv', 'relu') == layer_name:
                return x
        raise ValueError('unknown layer name: {}'.format(layer_name))
if __name__ == '__main__':
    test_vgg = VGG19(pretrained=False)
    # random batch standing in for real images
    x = torch.randn(10, 3, 224, 224)
    for i in range(1, 6):
        layer = 'relu_{}_1'.format(i)
        out = test_vgg.get_feature(x, layer_name=layer)
        print(layer, out.size())
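    # conv activations work the same way; 'conv_4_2' is the layer Gatys et
    # al. use as the content representation in neural style transfer
    out = test_vgg.get_feature(x, layer_name='conv_4_2')
    print('conv_4_2', out.size())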