Last active: December 5, 2019 16:37
-
-
Save luistelmocosta/d0d48614e1a8b655a3aa56323060d84a to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def get_upsampling_weight(in_channels, out_channels, kernel_size):
    """Build a bilinear-interpolation kernel for a transposed convolution.

    The 2-D filter is the outer product of a 1-D triangular ("tent") profile
    that peaks at the kernel centre. It is written onto the channel pairs
    ``weight[c, c]``; all other channel pairs stay zero.

    Args:
        in_channels: size of the first weight dimension.
        out_channels: size of the second weight dimension. The two channel
            index ranges are paired element-wise by NumPy fancy indexing, so
            they are expected to be equal; sizes NumPy cannot broadcast
            together raise ``IndexError``.
        kernel_size: spatial size of the square kernel.

    Returns:
        A float32 tensor of shape
        ``(in_channels, out_channels, kernel_size, kernel_size)``.
    """
    half = (kernel_size + 1) // 2
    # Odd kernels peak on a single centre tap; even kernels peak between
    # the two middle taps.
    mid = half - 1 if kernel_size % 2 == 1 else half - 0.5
    taps = np.arange(kernel_size)
    profile = 1 - np.abs(taps - mid) / half
    kernel2d = np.outer(profile, profile)

    out = np.zeros(
        (in_channels, out_channels, kernel_size, kernel_size), dtype=np.float64
    )
    out[np.arange(in_channels), np.arange(out_channels)] = kernel2d
    return torch.from_numpy(out).float()
class FCN8s(nn.Module):
    """FCN-8s fully convolutional network producing per-pixel class scores.

    Five convolutional stages whose channel layout matches VGG-16 (each stage
    ends in a 2x2, stride-2 max-pool) feed two convolutional "fc" layers
    (fc6: 7x7, fc7: 1x1) and a 1x1 scorer. The coarse scores are upsampled by
    transposed convolutions and summed with 1x1 scores computed from the
    pool4 and pool3 feature maps (the two skip connections of FCN-8s). The
    result is cropped back to the input's spatial size.

    Args:
        n_class: number of output score channels (default 21).
    """
    # A disabled pretrained-weight helper used to live here (as the class
    # docstring). It references ``osp`` and ``fcn``, which are not imported
    # in the visible code. Kept for reference:
    #
    # pretrained_model = \
    #     osp.expanduser('~/data/models/pytorch/fcn8s_from_caffe.pth')
    #
    # @classmethod
    # def download(cls):
    #     return fcn.data.cached_download(
    #         url='http://drive.google.com/uc?id=0B9P1L--7Wd2vT0FtdThWREhjNkU',
    #         path=cls.pretrained_model,
    #         md5='dbd9bbb3829a3184913bccc74373afbb',
    #     )

    def __init__(self, n_class: int = 21) -> None:
        super(FCN8s, self).__init__()
        # NOTE: attributes are registered in this order, which fixes both the
        # order of self.modules() and the state_dict key names.
        # conv1
        # padding=100 (instead of 1) enlarges every downstream feature map, so
        # fc6's unpadded 7x7 conv still has enough input for small images.
        # The fixed crop offsets in forward() presumably assume this value.
        self.conv1_1 = nn.Conv2d(3, 64, 3, padding=100)
        self.relu1_1 = nn.ReLU(inplace=True)
        self.conv1_2 = nn.Conv2d(64, 64, 3, padding=1)
        self.relu1_2 = nn.ReLU(inplace=True)
        # ceil_mode=True rounds the pooled size up instead of dropping an odd
        # trailing row/column.
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)  # 1/2
        # conv2
        self.conv2_1 = nn.Conv2d(64, 128, 3, padding=1)
        self.relu2_1 = nn.ReLU(inplace=True)
        self.conv2_2 = nn.Conv2d(128, 128, 3, padding=1)
        self.relu2_2 = nn.ReLU(inplace=True)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)  # 1/4
        # conv3
        self.conv3_1 = nn.Conv2d(128, 256, 3, padding=1)
        self.relu3_1 = nn.ReLU(inplace=True)
        self.conv3_2 = nn.Conv2d(256, 256, 3, padding=1)
        self.relu3_2 = nn.ReLU(inplace=True)
        self.conv3_3 = nn.Conv2d(256, 256, 3, padding=1)
        self.relu3_3 = nn.ReLU(inplace=True)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)  # 1/8
        # conv4
        self.conv4_1 = nn.Conv2d(256, 512, 3, padding=1)
        self.relu4_1 = nn.ReLU(inplace=True)
        self.conv4_2 = nn.Conv2d(512, 512, 3, padding=1)
        self.relu4_2 = nn.ReLU(inplace=True)
        self.conv4_3 = nn.Conv2d(512, 512, 3, padding=1)
        self.relu4_3 = nn.ReLU(inplace=True)
        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)  # 1/16
        # conv5
        self.conv5_1 = nn.Conv2d(512, 512, 3, padding=1)
        self.relu5_1 = nn.ReLU(inplace=True)
        self.conv5_2 = nn.Conv2d(512, 512, 3, padding=1)
        self.relu5_2 = nn.ReLU(inplace=True)
        self.conv5_3 = nn.Conv2d(512, 512, 3, padding=1)
        self.relu5_3 = nn.ReLU(inplace=True)
        self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)  # 1/32
        # fc6: a "fully connected" layer expressed as an unpadded 7x7 conv.
        self.fc6 = nn.Conv2d(512, 4096, 7)
        self.relu6 = nn.ReLU(inplace=True)
        self.drop6 = nn.Dropout2d()
        # fc7: a "fully connected" layer expressed as a 1x1 conv.
        self.fc7 = nn.Conv2d(4096, 4096, 1)
        self.relu7 = nn.ReLU(inplace=True)
        self.drop7 = nn.Dropout2d()
        # 1x1 scorers: fc7 features, and the pool3 / pool4 skip features,
        # each mapped to n_class score channels.
        self.score_fr = nn.Conv2d(4096, n_class, 1)
        self.score_pool3 = nn.Conv2d(256, n_class, 1)
        self.score_pool4 = nn.Conv2d(512, n_class, 1)
        # Transposed-conv upsamplers (no bias). Their weights are set to fixed
        # bilinear kernels in _initialize_weights().
        # upscore2: kernel 4, stride 2 -> roughly 2x upsampling.
        self.upscore2 = nn.ConvTranspose2d(
            n_class, n_class, 4, stride=2, bias=False)
        # upscore8: kernel 16, stride 8 -> roughly 8x upsampling.
        self.upscore8 = nn.ConvTranspose2d(
            n_class, n_class, 16, stride=8, bias=False)
        # upscore_pool4: kernel 4, stride 2 -> roughly 2x upsampling.
        self.upscore_pool4 = nn.ConvTranspose2d(
            n_class, n_class, 4, stride=2, bias=False)
        self._initialize_weights()

    def _initialize_weights(self) -> None:
        """Initialise all conv parameters in place.

        Every ``nn.Conv2d`` gets zero weights and (if present) a zero bias.
        Every ``nn.ConvTranspose2d`` gets the bilinear kernel built by
        ``get_upsampling_weight``. ``nn.ConvTranspose2d`` is not a subclass
        of ``nn.Conv2d``, so the two branches never overlap.

        NOTE(review): nothing in this file loads pretrained weights after
        this zero-fill. With all-zero conv weights, every conv/ReLU activation
        is zero and the output is all zeros. Only the score_fr / score_pool3 /
        score_pool4 biases then receive gradient. Confirm that weights (e.g.
        VGG-16) are copied in elsewhere before training.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                m.weight.data.zero_()
                if m.bias is not None:
                    m.bias.data.zero_()
            if isinstance(m, nn.ConvTranspose2d):
                # get_upsampling_weight only builds square kernels.
                assert m.kernel_size[0] == m.kernel_size[1]
                initial_weight = get_upsampling_weight(
                    m.in_channels, m.out_channels, m.kernel_size[0])
                m.weight.data.copy_(initial_weight)

    def forward(self, x):
        """Compute per-pixel class scores for ``x``.

        Args:
            x: input batch. Fed to a 3-input-channel conv, and its dims 2 and
                3 are used for the final crop, so presumably shaped
                ``(N, 3, H, W)`` -- confirm with callers.

        Returns:
            A contiguous tensor with ``n_class`` channels whose dims 2 and 3
            are cropped to match ``x``'s dims 2 and 3.
        """
        h = x
        h = self.relu1_1(self.conv1_1(h))
        h = self.relu1_2(self.conv1_2(h))
        h = self.pool1(h)
        h = self.relu2_1(self.conv2_1(h))
        h = self.relu2_2(self.conv2_2(h))
        h = self.pool2(h)
        h = self.relu3_1(self.conv3_1(h))
        h = self.relu3_2(self.conv3_2(h))
        h = self.relu3_3(self.conv3_3(h))
        h = self.pool3(h)
        pool3 = h  # 1/8
        h = self.relu4_1(self.conv4_1(h))
        h = self.relu4_2(self.conv4_2(h))
        h = self.relu4_3(self.conv4_3(h))
        h = self.pool4(h)
        pool4 = h  # 1/16
        h = self.relu5_1(self.conv5_1(h))
        h = self.relu5_2(self.conv5_2(h))
        h = self.relu5_3(self.conv5_3(h))
        h = self.pool5(h)
        h = self.relu6(self.fc6(h))
        h = self.drop6(h)
        h = self.relu7(self.fc7(h))
        h = self.drop7(h)
        h = self.score_fr(h)
        h = self.upscore2(h)
        upscore2 = h  # 1/16
        h = self.score_pool4(pool4)
        # Crop the pool4 scores to upscore2's spatial size. The fixed offsets
        # here (5) and below (9, 31) presumably compensate for conv1_1's
        # padding=100 and the transposed-conv geometry; they are constants,
        # not derived from the input.
        h = h[:, :, 5:5 + upscore2.size()[2], 5:5 + upscore2.size()[3]]
        score_pool4c = h  # 1/16
        # First skip fusion: element-wise sum of the two 1/16 score maps.
        h = upscore2 + score_pool4c  # 1/16
        h = self.upscore_pool4(h)
        upscore_pool4 = h  # 1/8
        h = self.score_pool3(pool3)
        h = h[:, :,
              9:9 + upscore_pool4.size()[2],
              9:9 + upscore_pool4.size()[3]]
        score_pool3c = h  # 1/8
        # Second skip fusion at 1/8 resolution.
        h = upscore_pool4 + score_pool3c  # 1/8
        h = self.upscore8(h)
        # Crop to the input's spatial size; .contiguous() materialises the
        # strided slice.
        h = h[:, :, 31:31 + x.size()[2], 31:31 + x.size()[3]].contiguous()
        return h

    def predict(self, x):
        """Run ``forward`` under ``torch.no_grad()`` and return the scores.

        NOTE(review): this does not call ``self.eval()``. If the module is in
        training mode, the Dropout2d layers stay active and results are
        stochastic; call ``model.eval()`` first for deterministic inference.
        """
        with torch.no_grad():
            x = self.forward(x)
        return x
# Instantiate the network with a single output score channel (n_class=1).
model = FCN8s(1)
def get_model_parameters(model, bias=False):
    """Yield the parameters of ``model``'s ``nn.Conv2d`` layers.

    Walks ``model.modules()`` in registration order. For every ``nn.Conv2d``
    it yields either the weight or, when ``bias`` is true, the bias.
    ``nn.ConvTranspose2d`` layers contribute nothing: in ``FCN8s`` they hold
    the fixed bilinear upsampling kernels, which are deliberately left out of
    the optimizer. Their ``requires_grad`` flag is not changed here.

    Args:
        model: the network to scan (e.g. an ``FCN8s`` instance).
        bias: if true, yield conv biases; otherwise yield conv weights.

    Yields:
        ``nn.Parameter`` objects. When ``bias`` is true, convs created with
        ``bias=False`` are skipped because they have no bias to optimize.

    Raises:
        ValueError: if the model contains a module type this helper does not
            know how to classify.
        AssertionError: if ``bias`` is true and an ``nn.ConvTranspose2d`` has
            a bias, which would otherwise silently escape the optimizer.
    """
    # Modules with no parameters of their own (activations, pooling, dropout),
    # or containers whose children are visited separately (nn.Sequential and
    # the FCN8s root). Add other top-level model classes here when reusing
    # this helper.
    modules_skipped = (
        nn.ReLU,
        nn.MaxPool2d,
        nn.Dropout2d,
        nn.Sequential,
        FCN8s,
    )
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            param = m.bias if bias else m.weight
            # A conv built with bias=False has m.bias == None; torch.optim
            # rejects None params, so skip it instead of yielding it.
            if param is not None:
                yield param
        elif isinstance(m, nn.ConvTranspose2d):
            # Bilinear upsampling kernel: intentionally not optimized.
            if bias:
                assert m.bias is None
        elif isinstance(m, modules_skipped):
            continue
        else:
            raise ValueError('Unexpected module: %s' % str(m))
# Adam with two parameter groups produced by get_model_parameters:
#   1. the nn.Conv2d weights,
#   2. the nn.Conv2d biases.
# nn.ConvTranspose2d upsampling kernels are in neither group. Both groups
# currently use the same learning rate, 0.01.
# NOTE(review): an earlier comment here talked about lowering the encoder lr
# so large early gradients don't disturb pretrained weights. However, these
# groups split weights from biases (not encoder from decoder) and their lrs
# are identical. Confirm the intended per-group values.
optimizer = torch.optim.Adam(
    [
        dict(params=get_model_parameters(model, bias=False), lr=0.01),
        dict(params=get_model_parameters(model, bias=True), lr=0.01),
    ]
)
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment