import torch
import torch.nn as nn
import numpy as np


class HED(nn.Module):
    """ HED network. """

    def __init__(self):
        super(HED, self).__init__()
        # VGG-16 backbone layers. Note the unusually large padding of 35 on
        # the first convolution, inherited from the official Caffe model.
        self.conv1_1 = nn.Conv2d(3, 64, 3, padding=35)
        self.conv1_2 = nn.Conv2d(64, 64, 3, padding=1)

        self.conv2_1 = nn.Conv2d(64, 128, 3, padding=1)
        self.conv2_2 = nn.Conv2d(128, 128, 3, padding=1)

        self.conv3_1 = nn.Conv2d(128, 256, 3, padding=1)
        self.conv3_2 = nn.Conv2d(256, 256, 3, padding=1)
        self.conv3_3 = nn.Conv2d(256, 256, 3, padding=1)

        self.conv4_1 = nn.Conv2d(256, 512, 3, padding=1)
        self.conv4_2 = nn.Conv2d(512, 512, 3, padding=1)
        self.conv4_3 = nn.Conv2d(512, 512, 3, padding=1)

        self.conv5_1 = nn.Conv2d(512, 512, 3, padding=1)
        self.conv5_2 = nn.Conv2d(512, 512, 3, padding=1)
        self.conv5_3 = nn.Conv2d(512, 512, 3, padding=1)

        self.relu = nn.ReLU()

        # Note on ceil_mode: when True, ceil instead of floor is used to
        # compute the output shape. Ceil mode is needed here because the
        # feature maps are later up-sampled and cropped back to the original
        # image size; without it, the up-sampled feature maps could end up
        # smaller than the original images.
        self.maxpool = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        # 1x1 score layers, one per side output (out channels: 1).
        self.score_dsn1 = nn.Conv2d(64, 1, 1)
        self.score_dsn2 = nn.Conv2d(128, 1, 1)
        self.score_dsn3 = nn.Conv2d(256, 1, 1)
        self.score_dsn4 = nn.Conv2d(512, 1, 1)
        self.score_dsn5 = nn.Conv2d(512, 1, 1)
        self.score_final = nn.Conv2d(5, 1, 1)

        # Fixed bilinear weights.
        self.register_buffer('weight_deconv2', make_bilinear_weights(4, 1))
        self.register_buffer('weight_deconv3', make_bilinear_weights(8, 1))
        self.register_buffer('weight_deconv4', make_bilinear_weights(16, 1))
        self.register_buffer('weight_deconv5', make_bilinear_weights(32, 1))
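        # (These kernels are registered as buffers rather than parameters:
        # they stay fixed during training but still move with .to(device) /
        # .cuda() and are saved in the state_dict.)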

        # Prepare for aligned crop.
        self.crop1_margin, self.crop2_margin, self.crop3_margin, \
            self.crop4_margin, self.crop5_margin = self.prepare_aligned_crop()

    # noinspection PyMethodMayBeStatic
    def prepare_aligned_crop(self):
        """ Prepare for aligned crop. """
        # Re-implement the logic in deploy.prototxt and
        # /hed/src/caffe/layers/crop_layer.cpp of the official repo.
        # Other reference materials:
        #   hed/include/caffe/layer.hpp
        #   hed/include/caffe/vision_layers.hpp
        #   hed/include/caffe/util/coords.hpp
        #   https://groups.google.com/forum/#!topic/caffe-users/YSRYy7Nd9J8
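        # Each mapping is an affine transform m = (a, b): a spatial coordinate
        # x in a layer's input corresponds to coordinate a * x + b in its
        # output (tracking receptive-field centers). For example, a 3x3
        # convolution with stride 1 and padding 1 leaves coordinates aligned,
        # conv_map(3, 1, 1) == (1.0, 0.0), while conv1_1's padding of 35
        # shifts the origin: conv_map(3, 1, 35) == (1.0, 34.0).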
        def map_inv(m):
            """ Mapping inverse. """
            a, b = m
            return 1 / a, -b / a

        def map_compose(m1, m2):
            """ Mapping composition: apply m2 first, then m1. """
            a1, b1 = m1
            a2, b2 = m2
            return a1 * a2, a1 * b2 + b1

        def deconv_map(kernel_h, stride_h, pad_h):
            """ Deconvolution coordinates mapping. """
            return stride_h, (kernel_h - 1) / 2 - pad_h

        def conv_map(kernel_h, stride_h, pad_h):
            """ Convolution coordinates mapping. """
            return map_inv(deconv_map(kernel_h, stride_h, pad_h))

        def pool_map(kernel_h, stride_h, pad_h):
            """ Pooling coordinates mapping. """
            return conv_map(kernel_h, stride_h, pad_h)
        x_map = (1, 0)
        conv1_1_map = map_compose(conv_map(3, 1, 35), x_map)
        conv1_2_map = map_compose(conv_map(3, 1, 1), conv1_1_map)
        pool1_map = map_compose(pool_map(2, 2, 0), conv1_2_map)

        conv2_1_map = map_compose(conv_map(3, 1, 1), pool1_map)
        conv2_2_map = map_compose(conv_map(3, 1, 1), conv2_1_map)
        pool2_map = map_compose(pool_map(2, 2, 0), conv2_2_map)

        conv3_1_map = map_compose(conv_map(3, 1, 1), pool2_map)
        conv3_2_map = map_compose(conv_map(3, 1, 1), conv3_1_map)
        conv3_3_map = map_compose(conv_map(3, 1, 1), conv3_2_map)
        pool3_map = map_compose(pool_map(2, 2, 0), conv3_3_map)

        conv4_1_map = map_compose(conv_map(3, 1, 1), pool3_map)
        conv4_2_map = map_compose(conv_map(3, 1, 1), conv4_1_map)
        conv4_3_map = map_compose(conv_map(3, 1, 1), conv4_2_map)
        pool4_map = map_compose(pool_map(2, 2, 0), conv4_3_map)

        conv5_1_map = map_compose(conv_map(3, 1, 1), pool4_map)
        conv5_2_map = map_compose(conv_map(3, 1, 1), conv5_1_map)
        conv5_3_map = map_compose(conv_map(3, 1, 1), conv5_2_map)

        score_dsn1_map = conv1_2_map
        score_dsn2_map = conv2_2_map
        score_dsn3_map = conv3_3_map
        score_dsn4_map = conv4_3_map
        score_dsn5_map = conv5_3_map

        upsample2_map = map_compose(deconv_map(4, 2, 0), score_dsn2_map)
        upsample3_map = map_compose(deconv_map(8, 4, 0), score_dsn3_map)
        upsample4_map = map_compose(deconv_map(16, 8, 0), score_dsn4_map)
        upsample5_map = map_compose(deconv_map(32, 16, 0), score_dsn5_map)
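        # With conv1_1's padding of 35, the offsets below evaluate to
        # 34, 35, 36, 38 and 42 pixels for side outputs 1-5 respectively.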
        crop1_margin = int(score_dsn1_map[1])
        crop2_margin = int(upsample2_map[1])
        crop3_margin = int(upsample3_map[1])
        crop4_margin = int(upsample4_map[1])
        crop5_margin = int(upsample5_map[1])

        return crop1_margin, crop2_margin, crop3_margin, crop4_margin, crop5_margin

    def forward(self, x):
        # VGG-16 network.
        image_h, image_w = x.shape[2], x.shape[3]

        conv1_1 = self.relu(self.conv1_1(x))
        conv1_2 = self.relu(self.conv1_2(conv1_1))  # Side output 1.
        pool1 = self.maxpool(conv1_2)

        conv2_1 = self.relu(self.conv2_1(pool1))
        conv2_2 = self.relu(self.conv2_2(conv2_1))  # Side output 2.
        pool2 = self.maxpool(conv2_2)

        conv3_1 = self.relu(self.conv3_1(pool2))
        conv3_2 = self.relu(self.conv3_2(conv3_1))
        conv3_3 = self.relu(self.conv3_3(conv3_2))  # Side output 3.
        pool3 = self.maxpool(conv3_3)

        conv4_1 = self.relu(self.conv4_1(pool3))
        conv4_2 = self.relu(self.conv4_2(conv4_1))
        conv4_3 = self.relu(self.conv4_3(conv4_2))  # Side output 4.
        pool4 = self.maxpool(conv4_3)

        conv5_1 = self.relu(self.conv5_1(pool4))
        conv5_2 = self.relu(self.conv5_2(conv5_1))
        conv5_3 = self.relu(self.conv5_3(conv5_2))  # Side output 5.

        score_dsn1 = self.score_dsn1(conv1_2)
        score_dsn2 = self.score_dsn2(conv2_2)
        score_dsn3 = self.score_dsn3(conv3_3)
        score_dsn4 = self.score_dsn4(conv4_3)
        score_dsn5 = self.score_dsn5(conv5_3)
        upsample2 = torch.nn.functional.conv_transpose2d(score_dsn2, self.weight_deconv2, stride=2)
        upsample3 = torch.nn.functional.conv_transpose2d(score_dsn3, self.weight_deconv3, stride=4)
        upsample4 = torch.nn.functional.conv_transpose2d(score_dsn4, self.weight_deconv4, stride=8)
        upsample5 = torch.nn.functional.conv_transpose2d(score_dsn5, self.weight_deconv5, stride=16)
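        # Side output 1 is already at full resolution (stride 1), so only
        # side outputs 2-5 need the fixed bilinear up-sampling (strides 2-16).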
        # Aligned cropping.
        crop1 = score_dsn1[:, :, self.crop1_margin:self.crop1_margin + image_h,
                           self.crop1_margin:self.crop1_margin + image_w]
        crop2 = upsample2[:, :, self.crop2_margin:self.crop2_margin + image_h,
                          self.crop2_margin:self.crop2_margin + image_w]
        crop3 = upsample3[:, :, self.crop3_margin:self.crop3_margin + image_h,
                          self.crop3_margin:self.crop3_margin + image_w]
        crop4 = upsample4[:, :, self.crop4_margin:self.crop4_margin + image_h,
                          self.crop4_margin:self.crop4_margin + image_w]
        crop5 = upsample5[:, :, self.crop5_margin:self.crop5_margin + image_h,
                          self.crop5_margin:self.crop5_margin + image_w]

        # Concatenate along the channel dimension and fuse.
        fuse_cat = torch.cat((crop1, crop2, crop3, crop4, crop5), dim=1)
        fuse = self.score_final(fuse_cat)  # Shape: [batch_size, 1, image_h, image_w].
        return torch.sigmoid(fuse)

def make_bilinear_weights(size, num_channels):
    """ Generate bilinear interpolation weights as up-sampling filters (following the FCN paper). """
    factor = (size + 1) // 2
    if size % 2 == 1:
        center = factor - 1
    else:
        center = factor - 0.5
    og = np.ogrid[:size, :size]
    filt = (1 - abs(og[0] - center) / factor) * (1 - abs(og[1] - center) / factor)
    filt = torch.from_numpy(filt).float()
    w = torch.zeros(num_channels, num_channels, size, size)
    w.requires_grad = False  # Fixed, not trainable.
    # Only the diagonal entries are non-zero: each channel is up-sampled
    # independently with the same bilinear kernel.
    for i in range(num_channels):
        for j in range(num_channels):
            if i == j:
                w[i, j] = filt
    return w
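

# A minimal smoke test (not part of the original network definition). The
# input size 321x481 is an assumption for illustration (a BSDS500-like
# image); any reasonable H x W should work, since the aligned crop restores
# the input resolution.
if __name__ == '__main__':
    net = HED()
    dummy = torch.randn(1, 3, 321, 481)
    with torch.no_grad():
        edges = net(dummy)
    print(edges.shape)  # Expected: torch.Size([1, 1, 321, 481])
    # The 4x4 bilinear kernel used for the stride-2 deconvolution:
    print(make_bilinear_weights(4, 1)[0, 0])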