Skip to content

Instantly share code, notes, and snippets.

@snakers4
Created November 12, 2017 14:12
Show Gist options
  • Save snakers4/914ac6eb6911d5ebe73a1c41ea3a3966 to your computer and use it in GitHub Desktop.
Save snakers4/914ac6eb6911d5ebe73a1c41ea3a3966 to your computer and use it in GitHub Desktop.
import torchvision.models as models
import torch
import torch.nn as nn
class FineTuneModel(nn.Module):
    """Two-stream classifier head on top of a shared, pretrained VGG16 trunk.

    The input carries two 3-channel images stacked along the channel axis
    (6 channels in total). Channels 0-2 and channels 3-5 are each passed
    through the *same* ``features`` module, so the two streams share weights.
    The two resulting feature maps are concatenated along the channel axis,
    flattened, and fed to a freshly initialised fully-connected classifier.

    Args:
        original_model: A model exposing a ``features`` sub-module, e.g.
            ``torchvision.models.vgg16(pretrained=True)``.
        arch: Architecture name. Only names starting with ``'vgg16'`` are
            supported.
        num_classes: Number of logits produced by the final linear layer.
        freeze: If truthy, ``requires_grad`` is switched off for every
            parameter of the shared ``features`` trunk, so only the
            classifier is trained.
        input_size: Spatial size of each input image, either an ``int``
            (square images) or a ``(height, width)`` pair. Defaults to 64,
            which reproduces the original hard-coded classifier width of
            ``512 * 2 * 2 * 2``.

    Raises:
        NotImplementedError: If ``arch`` is not a VGG16 variant.
        ValueError: If ``input_size`` is smaller than the trunk's total
            down-sampling factor, which would leave an empty feature map.
    """

    # VGG16's convolutional trunk ends with 512 channels and contains five
    # 2x2 / stride-2 max-pools, so each spatial axis is down-sampled by 32
    # (64x64 input -> 2x2 feature map, matching the original 512*2*2*2).
    _VGG16_OUT_CHANNELS = 512
    _VGG16_DOWNSAMPLE = 32

    def __init__(self,
                 original_model,
                 arch,
                 num_classes,
                 freeze,
                 input_size=64):
        super(FineTuneModel, self).__init__()
        if not arch.startswith('vgg16'):
            # Must be a real exception instance: raising a bare string is a
            # TypeError in Python 3 and would hide this message.
            raise NotImplementedError(
                "Finetuning not supported on this architecture yet")

        if isinstance(input_size, int):
            height, width = input_size, input_size
        else:
            height, width = input_size
        feat_h = height // self._VGG16_DOWNSAMPLE
        feat_w = width // self._VGG16_DOWNSAMPLE
        if feat_h < 1 or feat_w < 1:
            raise ValueError(
                "input_size must be at least %d pixels per side, got %r"
                % (self._VGG16_DOWNSAMPLE, input_size))
        # forward() concatenates two feature maps channel-wise, hence the 2.
        in_features = 2 * self._VGG16_OUT_CHANNELS * feat_h * feat_w

        self.features = original_model.features
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(in_features, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes),
        )
        self.modelName = 'vgg16'

        # Freeze those weights
        if freeze:
            print('Core model layers are frozen')
            for p in self.features.parameters():
                p.requires_grad = False

    def forward(self, x):
        """Classify a batch of stacked image pairs.

        Args:
            x: 4-D tensor whose channel axis (dim 1) holds two 3-channel
                images back to back. It is presumably shaped
                (batch, 6, H, W), with H and W matching the ``input_size``
                the classifier was built for.

        Returns:
            Tensor of logits with shape (batch, num_classes).
        """
        f1 = self.features(x[:, 0:3, :, :])
        f2 = self.features(x[:, 3:6, :, :])
        # Both halves go through the same trunk (shared weights), then the
        # two feature maps are stacked along the channel axis.
        f = torch.cat((f1, f2), 1)
        if self.modelName == 'vgg16':
            f = f.view(f.size(0), -1)
        return self.classifier(f)
# Demo: wrap an ImageNet-pretrained VGG16 (pretrained=True asks torchvision
# for its pretrained weights) in a 2-class, non-frozen FineTuneModel and run
# one random 6-channel 64x64 sample through it.
original_model = models.vgg16(pretrained=True)
model = FineTuneModel(original_model,
                      arch='vgg16',
                      num_classes=2,
                      freeze=False)

# Sample - we are feeding a sample of 6 channel images
# (two 3-channel images stacked along the channel axis).
# Can be easily rewritten to take 5 dimensional tensor
# The 64x64 spatial size matches the size the classifier was built for.
# NOTE: `input` shadows the builtin; the name is kept so existing references
# to this module-level variable keep working.
input = torch.autograd.Variable(torch.randn(1, 6, 64, 64))
output = model(input)
# Report the result shape instead of discarding it (expected: [1, 2]).
print(output.size())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment