Pytorch
# Finetune a pretrained model while adding some new modules.
import torch
import torch.nn as nn


class Customize(nn.Module):
    def __init__(self, pre_model):
        """Load the pretrained model, replace the last fc layer and add some new layers."""
        super(Customize, self).__init__()
        self.features = pre_model
        # To freeze the pretrained weights, uncomment:
        # -------------------------------------------
        # for param in self.features.parameters():
        #     param.requires_grad = False
        # -------------------------------------------
        # Use 1 input channel for grayscale images
        self.features.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        # Replace the last fc layer with a suitable number of output features
        self.features.fc = nn.Linear(self.features.fc.in_features, 500)
        # Add new modules here
        # -------------------------------------------------------
        self.new_fc = nn.Sequential(
            nn.Linear(500, 100), nn.ReLU(),   # 500 matches the output of self.features.fc
            nn.Linear(100, 1)
        )
        # -------------------------------------------------------
        self.init_weights()

    def init_weights(self):
        """Initialize the weights of the replaced fc layer."""
        self.features.fc.weight.data.normal_(0.0, 0.02)
        self.features.fc.bias.data.fill_(0.01)

    def forward(self, inputs):
        """Extract the image feature vectors and pass them through the new fc layers."""
        output = self.new_fc(self.features(inputs))
        return output
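
# A minimal usage sketch (my assumption, not part of the original gist): wrap a
# torchvision ResNet-18 and only optimize the parameters that require gradients.
# The model choice and optimizer settings below are illustrative, not prescribed.
import torchvision

pre_model = torchvision.models.resnet18(pretrained=True)
model = Customize(pre_model)
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(params, lr=1e-4)

dummy = torch.randn(4, 1, 224, 224)  # batch of 4 single-channel images
output = model(dummy)                # shape: (4, 1)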

# Initialize the fully-connected layers of a model instance (such as `model` above)
from torch.nn import init

for m in model.modules():
    if isinstance(m, torch.nn.Linear):
        init.kaiming_normal_(m.weight)
        init.constant_(m.bias, 0.01)
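
# The same initialization can also be expressed as a function passed to Module.apply,
# which visits every submodule recursively (a sketch of an equivalent idiom,
# not from the original gist).
def init_linear(m):
    if isinstance(m, torch.nn.Linear):
        init.kaiming_normal_(m.weight)
        init.constant_(m.bias, 0.01)

model.apply(init_linear)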