Skip to content

Instantly share code, notes, and snippets.

@vlasenkov
Created February 1, 2019 11:19
Show Gist options
  • Save vlasenkov/5338ccc7e2afcf9baa38326dc21fe403 to your computer and use it in GitHub Desktop.
Save vlasenkov/5338ccc7e2afcf9baa38326dc21fe403 to your computer and use it in GitHub Desktop.
import os

# Must be set before CUDA is initialised so torch only sees GPU 0.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
def get_upsampling_weight(in_channels, out_channels, kernel_size):
"""Make a 2D bilinear kernel suitable for upsampling"""
factor = (kernel_size + 1) // 2
if kernel_size % 2 == 1:
center = factor - 1
else:
center = factor - 0.5
og = np.ogrid[:kernel_size, :kernel_size]
filt = (1 - abs(og[0] - center) / factor) * (1 - abs(og[1] - center) / factor)
weight = np.zeros((in_channels, out_channels, kernel_size, kernel_size), dtype=np.float64)
weight[range(in_channels), range(out_channels), :, :] = filt
return torch.from_numpy(weight).float()
class FCN8s(nn.Module):
def __init__(self, n_classes=21, learned_billinear=True, dropout=0.5):
super(FCN8s, self).__init__()
self.learned_billinear = learned_billinear
self.n_classes = n_classes
self.conv_block1 = nn.Sequential(
nn.Conv2d(3, 64, 3, padding=100),
nn.ReLU(inplace=True),
nn.Conv2d(64, 64, 3, padding=1),
nn.ReLU(inplace=True),
nn.MaxPool2d(2, stride=2, ceil_mode=True),
)
self.conv_block2 = nn.Sequential(
nn.Conv2d(64, 128, 3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(128, 128, 3, padding=1),
nn.ReLU(inplace=True),
nn.MaxPool2d(2, stride=2, ceil_mode=True),
)
self.conv_block3 = nn.Sequential(
nn.Conv2d(128, 256, 3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(256, 256, 3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(256, 256, 3, padding=1),
nn.ReLU(inplace=True),
nn.MaxPool2d(2, stride=2, ceil_mode=True),
)
self.conv_block4 = nn.Sequential(
nn.Conv2d(256, 512, 3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(512, 512, 3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(512, 512, 3, padding=1),
nn.ReLU(inplace=True),
nn.MaxPool2d(2, stride=2, ceil_mode=True),
)
self.conv_block5 = nn.Sequential(
nn.Conv2d(512, 512, 3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(512, 512, 3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(512, 512, 3, padding=1),
nn.ReLU(inplace=True),
nn.MaxPool2d(2, stride=2, ceil_mode=True),
)
self.classifier = nn.Sequential(
nn.Conv2d(512, 4096, 7),
nn.ReLU(inplace=True),
nn.Dropout2d(dropout),
nn.Conv2d(4096, 4096, 1),
nn.ReLU(inplace=True),
nn.Dropout2d(dropout),
nn.Conv2d(4096, self.n_classes, 1),
)
self.score_pool4 = nn.Conv2d(512, self.n_classes, 1)
self.score_pool3 = nn.Conv2d(256, self.n_classes, 1)
if self.learned_billinear:
self.upscore2 = nn.ConvTranspose2d(
self.n_classes, self.n_classes, 4, stride=2, bias=False
)
self.upscore4 = nn.ConvTranspose2d(
self.n_classes, self.n_classes, 4, stride=2, bias=False
)
self.upscore8 = nn.ConvTranspose2d(
self.n_classes, self.n_classes, 16, stride=8, bias=False
)
for m in self.modules():
if isinstance(m, nn.ConvTranspose2d):
m.weight.data.copy_(
get_upsampling_weight(m.in_channels, m.out_channels, m.kernel_size[0])
)
def forward(self, x):
conv1 = self.conv_block1(x)
conv2 = self.conv_block2(conv1)
conv3 = self.conv_block3(conv2)
conv4 = self.conv_block4(conv3)
conv5 = self.conv_block5(conv4)
score = self.classifier(conv5)
if self.learned_billinear:
upscore2 = self.upscore2(score)
score_pool4c = self.score_pool4(conv4)[
:, :, 5: 5 + upscore2.size()[2], 5: 5 + upscore2.size()[3]
]
upscore_pool4 = self.upscore4(upscore2 + score_pool4c)
score_pool3c = self.score_pool3(conv3)[
:, :, 9: 9 + upscore_pool4.size()[2], 9: 9 + upscore_pool4.size()[3]
]
out = self.upscore8(score_pool3c + upscore_pool4)[
:, :, 31: 31 + x.size()[2], 31: 31 + x.size()[3]
]
return out.contiguous()
else:
score_pool4 = self.score_pool4(conv4)
score_pool3 = self.score_pool3(conv3)
score = F.upsample(score, score_pool4.size()[2:])
score += score_pool4
score = F.upsample(score, score_pool3.size()[2:])
score += score_pool3
out = F.upsample(score, x.size()[2:])
return out
def init_vgg16_params(self, vgg16, copy_fc8=True):
blocks = [
self.conv_block1,
self.conv_block2,
self.conv_block3,
self.conv_block4,
self.conv_block5,
]
ranges = [[0, 4], [5, 9], [10, 16], [17, 23], [24, 29]]
features = list(vgg16.features.children())
for idx, conv_block in enumerate(blocks):
for l1, l2 in zip(features[ranges[idx][0]: ranges[idx][1]], conv_block):
if isinstance(l1, nn.Conv2d) and isinstance(l2, nn.Conv2d):
assert l1.weight.size() == l2.weight.size()
assert l1.bias.size() == l2.bias.size()
l2.weight.data = l1.weight.data
l2.bias.data = l1.bias.data
for i1, i2 in zip([0, 3], [0, 3]):
l1 = vgg16.classifier[i1]
l2 = self.classifier[i2]
l2.weight.data = l1.weight.data.view(l2.weight.size())
l2.bias.data = l1.bias.data.view(l2.bias.size())
n_class = self.classifier[6].weight.size()[0]
if copy_fc8:
l1 = vgg16.classifier[6]
l2 = self.classifier[6]
l2.weight.data = l1.weight.data[:n_class, :].view(l2.weight.size())
l2.bias.data = l1.bias.data[:n_class]
nn.init.zeros_(self.score_pool4.weight)
nn.init.zeros_(self.score_pool4.bias)
nn.init.zeros_(self.score_pool3.weight)
nn.init.zeros_(self.score_pool3.bias)
def init_fcn16s_params(self, fcn16s):
self.conv_block1.load_state_dict(fcn16s.conv_block1.state_dict())
self.conv_block2.load_state_dict(fcn16s.conv_block2.state_dict())
self.conv_block3.load_state_dict(fcn16s.conv_block3.state_dict())
self.conv_block4.load_state_dict(fcn16s.conv_block4.state_dict())
self.conv_block5.load_state_dict(fcn16s.conv_block5.state_dict())
self.classifier.load_state_dict(fcn16s.classifier.state_dict())
self.score_pool4.load_state_dict(fcn16s.score_pool4.state_dict())
nn.init.zeros_(self.score_pool3.weight)
nn.init.zeros_(self.score_pool3.bias)
BYTES_IN_GB = 1024 ** 3
def memuse():
return 'ALLOCATED: {:>6.3f} ({:>6.3f}) CACHED: {:>6.3f} ({:>6.3f})'.format(
torch.cuda.memory_allocated() / BYTES_IN_GB,
torch.cuda.max_memory_allocated() / BYTES_IN_GB,
torch.cuda.memory_cached() / BYTES_IN_GB,
torch.cuda.max_memory_cached() / BYTES_IN_GB,
)
model = FCN8s(21)
model.cuda()
model.train()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss(ignore_index=255)
for i in range(4):
img = torch.randn(1, 3, 128, 128).cuda()
lbl = torch.randint(0, 5, (1, 128, 128)).cuda()
optimizer.zero_grad()
out = model(img)
loss = loss_fn(input=out, target=lbl)
loss.backward()
optimizer.step()
print(memuse())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment