Some modifications to the GNN used in 3DGNN
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.utils import _pair, _quadruple
class MedianPool2d(nn.Module):
    """Median pool (usable as median filter when stride=1) module.

    Args:
        kernel_size: size of pooling kernel, int or 2-tuple
        stride: pool stride, int or 2-tuple
        padding: pool padding, int or 4-tuple (l, r, t, b) as in pytorch F.pad
        same: override padding and enforce same padding, boolean
    """

    def __init__(self, kernel_size=3, stride=1, padding=0, same=False):
        super(MedianPool2d, self).__init__()
        self.k = _pair(kernel_size)
        self.stride = _pair(stride)
        self.padding = _quadruple(padding)  # convert to l, r, t, b
        self.same = same

    def _padding(self, x):
        if self.same:
            ih, iw = x.size()[2:]
            if ih % self.stride[0] == 0:
                ph = max(self.k[0] - self.stride[0], 0)
            else:
                ph = max(self.k[0] - (ih % self.stride[0]), 0)
            if iw % self.stride[1] == 0:
                pw = max(self.k[1] - self.stride[1], 0)
            else:
                pw = max(self.k[1] - (iw % self.stride[1]), 0)
            pl = pw // 2
            pr = pw - pl
            pt = ph // 2
            pb = ph - pt
            padding = (pl, pr, pt, pb)
        else:
            padding = self.padding
        return padding
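    # Worked example: with kernel_size=32 and stride=32 (the configuration the
    # GNNs below use) and a 224x224 input, ih % stride == 0, so 'same' padding
    # would be ph = max(32 - 32, 0) = 0; nothing is padded and the image pools
    # down to exactly 7x7, matching the 7x7 CNN feature map in the __main__ test.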
    def forward(self, x):
        # using existing pytorch functions and tensor ops so that we get autograd,
        # would likely be more efficient to implement from scratch at C/CUDA level
        x = F.pad(x, self._padding(x), mode='reflect')
        x = x.unfold(2, self.k[0], self.stride[0]).unfold(3, self.k[1],
                                                          self.stride[1])
        x = x.contiguous().view(x.size()[:4] + (-1,)).median(dim=-1)[0]
        return x
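# A quick sanity check for MedianPool2d (a minimal sketch, not part of the
# original gist): with stride=1 and same=True it behaves as a classic median
# filter, so an isolated salt-noise pixel in a constant image is removed.
#   mp = MedianPool2d(kernel_size=3, stride=1, same=True)
#   img = torch.zeros(1, 1, 8, 8)
#   img[0, 0, 4, 4] = 1.0              # single outlier pixel
#   assert mp(img).abs().sum() == 0    # the 3x3 median suppresses it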
class GNNV1(nn.Module):
    def __init__(self, mlp_num, k, gnn_iterations, device):
        super(GNNV1, self).__init__()
        self.k = k
        self.device = device
        self.gnn_iternum = gnn_iterations
        self.median_pool = MedianPool2d(
            kernel_size=32, stride=32, padding=0, same=False)
        self.g_rnn_layers = nn.ModuleList(
            [nn.Linear(2048, 2048) for _ in range(mlp_num)])
        self.g_rnn_actfs = nn.ModuleList(
            [nn.ReLU(inplace=True) for _ in range(mlp_num)])
        self.q_rnn_layer = nn.Linear(4096, 2048)
        self.q_rnn_actf = nn.ReLU(inplace=True)
        self.output_conv = nn.Conv2d(4096, 2048, 3, stride=1, padding=1,
                                     bias=True)
    def forward(self, cnn_encoder_output, original_input):
        """
        :param cnn_encoder_output: output of the CNN encoder
        :param original_input: original image input
        """
        # extract for convenience
        N, C, H, W = cnn_encoder_output.size()
        K = self.k

        # to obtain position coordinates in 3D space, the RGB color values of
        # the original image are used here as a stand-in calibration
        proj_3d = self.median_pool(original_input)  # N 3 H W
        # N H*W 3
        proj_3d = proj_3d.view(N, 3, (H * W)).transpose(2, 1).contiguous()

        # get k nearest neighbors
        knn = self.__get_knn_indices(proj_3d)  # N HW K
        knn = knn.view(N, H * W * K).long()  # N HWK

        # prepare CNN encoded features for RNN
        h = cnn_encoder_output  # N C H W
        # after permuting dimensions, contiguous() is generally required
        # before view() can be used
        h = h.permute(0, 2, 3, 1).contiguous()  # N H W C
        h = h.view(N, (H * W), C)  # N HW C

        # aggregate and iterate messages in m, keep original CNN features h for later
        m = h.clone()  # N HW C

        # loop over timesteps to unroll
        for i in range(self.gnn_iternum):
            # do this for every sample in batch, not nice, but I don't know
            # how to use index_select batchwise
            # todo: this could be sped up with batch-level index selection
            # (a batched sketch follows this method)
            for n in range(N):
                # fetch features from nearest neighbors
                # gather information from each sample's neighbors;
                # repeated indices are allowed here
                neighbor_f = torch.index_select(
                    h[n], 0, knn[n]).view(H * W, K, C)  # HW K C
                # run neighbor features through MLP g and activation function
                # todo: could the multi-layer MLP be replaced with convolutions?
                # several linear layers + ReLU
                for g_line, g_actf in zip(self.g_rnn_layers, self.g_rnn_actfs):
                    # transform the K x C neighborhood representation of each
                    # of the HW points; the linear (fully connected) layer
                    # maps HW x K x C_in => HW x K x C_out
                    neighbor_f = g_line(neighbor_f)
                    neighbor_f = g_actf(neighbor_f)
                # HW K C
                # average over activated neighbors
                m[n] = torch.mean(neighbor_f, dim=1)  # HW C

            # concatenate current state with messages
            concat = torch.cat((h, m), 2)  # N HW 2C
            # get new features by running MLP q and activation function
            h = self.q_rnn_actf(self.q_rnn_layer(concat))  # N HW C

        # format RNN activations back to image, concatenate original CNN embedding, return
        h = h.view(N, H, W, C).permute(0, 3, 1, 2).contiguous()  # N C H W
        output = self.output_conv(
            torch.cat((cnn_encoder_output, h), 1))  # N 2C H W
        return output
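    # A batched alternative to the per-sample loop above (a sketch, not the
    # author's code): offset each sample's neighbor indices by n * H * W and
    # gather once from the flattened feature tensor.
    #   offsets = torch.arange(N, device=knn.device).view(N, 1) * (H * W)
    #   flat_knn = (knn + offsets).view(-1)                        # N*HW*K
    #   neighbor_f = h.view(N * H * W, C)[flat_knn].view(N, H * W, K, C)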
    # (x - y)^2 = x^2 - 2*x*y + y^2
    def __get_knn_indices(self, batch_mat):
        r = torch.bmm(batch_mat, batch_mat.permute(0, 2, 1))
        N, HW, _ = r.size()
        batch_indices = torch.zeros((N, HW, self.k)).to(self.device)
        for idx, val in enumerate(r):
            # get the diagonal elements
            diag = val.diag().unsqueeze(0)
            diag = diag.expand_as(val)
            # compute the distance matrix; clamp guards against small negative
            # values from floating-point error before the square root
            D = (diag + diag.t() - 2 * val).clamp(min=0).sqrt()
            topk, indices = torch.topk(D, k=self.k, largest=False)
            batch_indices[idx] = indices.data
        return batch_indices
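    # A fully batched variant of the distance computation (a sketch): the same
    # squared-distance expansion applies across the whole batch without the
    # Python loop, and since sqrt is monotonic, topk can be taken on squared
    # distances directly.
    #   sq = batch_mat.pow(2).sum(dim=2)                         # N HW
    #   D2 = (sq.unsqueeze(2) + sq.unsqueeze(1) - 2 * r).clamp(min=0)
    #   _, indices = torch.topk(D2, k=self.k, largest=False, dim=2)  # N HW K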
class GNNV2(nn.Module):
    def __init__(self, mlp_num, k, gnn_iterations, device):
        super(GNNV2, self).__init__()
        self.k = k
        self.device = device
        self.gnn_iternum = gnn_iterations
        self.median_pool = MedianPool2d(
            kernel_size=32, stride=32, padding=0, same=False)
        # build fresh modules per block: multiplying a list of module
        # instances (`[...] * mlp_num`) would reuse the same layer objects
        # and share their parameters across blocks
        g_rnn_conv_list = []
        for _ in range(mlp_num):
            g_rnn_conv_list += [
                nn.Conv2d(2048 * self.k, 2048 * self.k, 1),
                nn.BatchNorm2d(2048 * self.k),
                nn.ReLU(inplace=True)
            ]
        self.g_rnn_conv = nn.Sequential(*g_rnn_conv_list)
        self.q_rnn_layer = nn.Linear(4096, 2048)
        self.q_rnn_actf = nn.ReLU(inplace=True)
        self.output_conv = nn.Conv2d(4096, 2048, 3, stride=1, padding=1,
                                     bias=True)
    def forward(self, cnn_encoder_output, original_input):
        """
        :param cnn_encoder_output: output of the CNN encoder
        :param original_input: original image input
        """
        # extract for convenience
        N, C, H, W = cnn_encoder_output.size()
        K = self.k

        # to obtain position coordinates in 3D space, the RGB color values of
        # the original image are used here as a stand-in calibration
        proj_3d = self.median_pool(original_input)  # N 3 H W
        # N H*W 3
        proj_3d = proj_3d.view(N, 3, (H * W)).transpose(2, 1).contiguous()

        # get k nearest neighbors
        knn = self.__get_knn_indices(proj_3d)  # N HW K
        # offset each sample's indices by n * H * W so they address the right
        # sample once the features are flattened to (N*HW, C); without this,
        # every sample would gather its neighbors from sample 0
        offsets = (torch.arange(N, device=knn.device) * H * W).view(N, 1, 1)
        knn = (knn + offsets).view(N * H * W * K).long()  # NHWK

        # prepare CNN encoded features for RNN
        h = cnn_encoder_output  # N C H W
        # after permuting dimensions, contiguous() is generally required
        # before view() can be used
        h = h.permute(0, 2, 3, 1).contiguous()  # N H W C

        # loop over timesteps to unroll
        for i in range(self.gnn_iternum):
            # fetch features from all samples' neighbors at once;
            # repeated indices are allowed here
            h = h.view(N * (H * W), C)  # NHW C
            # N H W K*C
            neighbor_f = torch.index_select(h, 0, knn).view(N, H, W, K * C)
            neighbor_f = neighbor_f.permute(0, 3, 1, 2)  # N KC H W
            neighbor_f = self.g_rnn_conv(neighbor_f)
            neighbor_f = neighbor_f.permute(0, 2, 3, 1).contiguous()  # N H W KC
            neighbor_f = neighbor_f.view(N, H * W, K, C)
            # aggregate and iterate messages in m, keep original CNN features h for later
            m = torch.mean(neighbor_f, dim=2)
            h = h.view(N, (H * W), C)
            # concatenate current state with messages
            concat = torch.cat((h, m), 2)  # N HW 2C
            # get new features by running MLP q and activation function
            h = self.q_rnn_actf(self.q_rnn_layer(concat))  # N HW C

        # format RNN activations back to image, concatenate original CNN embedding, return
        h = h.view(N, H, W, C).permute(0, 3, 1, 2).contiguous()  # N C H W
        output = self.output_conv(
            torch.cat((cnn_encoder_output, h), 1))  # N 2C H W
        return output
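    # Shape walk-through for one iteration, using the sizes from the __main__
    # test below (N=5, C=2048, H=W=7, K=12): h flattened to (245, 2048) ->
    # gather (2940, 2048) -> view (5, 7, 7, 24576) -> 1x1 conv stack keeps
    # (5, 24576, 7, 7) -> view (5, 49, 12, 2048) -> mean over K -> (5, 49, 2048)
    # -> concat with h (5, 49, 4096) -> linear q -> (5, 49, 2048).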
    # adapted from
    # https://discuss.pytorch.org/t/build-your-own-loss-function-in-pytorch/235/6
    # (x - y)^2 = x^2 - 2*x*y + y^2
    def __get_knn_indices(self, batch_mat):
        r = torch.bmm(batch_mat, batch_mat.permute(0, 2, 1))
        N, HW, _ = r.size()
        batch_indices = torch.zeros((N, HW, self.k)).to(self.device)
        for idx, val in enumerate(r):
            # get the diagonal elements
            diag = val.diag().unsqueeze(0)
            diag = diag.expand_as(val)
            # compute the distance matrix; clamp guards against small negative
            # values from floating-point error before the square root
            D = (diag + diag.t() - 2 * val).clamp(min=0).sqrt()
            topk, indices = torch.topk(D, k=self.k, largest=False)
            batch_indices[idx] = indices.data
        return batch_indices
class GNNV3(nn.Module):
    def __init__(self, mlp_num, k, gnn_iterations, device):
        super(GNNV3, self).__init__()
        self.k = k
        self.device = device
        self.gnn_iternum = gnn_iterations
        self.median_pool = MedianPool2d(
            kernel_size=32, stride=32, padding=0, same=False)
        # build fresh modules per block: multiplying a list of module
        # instances (`[...] * mlp_num`) would reuse the same layer objects
        # and share their parameters across blocks
        g_rnn_conv_list = []
        for _ in range(mlp_num):
            g_rnn_conv_list += [
                nn.Conv2d(2048 * self.k, 2048 * self.k, 1),
                nn.BatchNorm2d(2048 * self.k),
                nn.ReLU(inplace=True)
            ]
        self.g_rnn_conv = nn.Sequential(*g_rnn_conv_list)
        self.q_rnn_conv = nn.Sequential(
            nn.Conv2d(4096, 2048, 1),
            nn.BatchNorm2d(2048),
            nn.ReLU(inplace=True)
        )
        self.output_conv = nn.Conv2d(4096, 2048, 3,
                                     stride=1,
                                     padding=1,
                                     bias=True)
    def forward(self, cnn_encoder_output, original_input):
        """
        :param cnn_encoder_output: output of the CNN encoder
        :param original_input: original image input
        """
        # extract for convenience
        N, C, H, W = cnn_encoder_output.size()
        K = self.k

        # to obtain position coordinates in 3D space, the RGB color values of
        # the original image are used here as a stand-in calibration
        proj_3d = self.median_pool(original_input)  # N 3 H W
        # N H*W 3
        proj_3d = proj_3d.view(N, 3, (H * W)).transpose(2, 1).contiguous()

        # get k nearest neighbors
        knn = self.__get_knn_indices(proj_3d)  # N HW K
        # offset each sample's indices by n * H * W so they address the right
        # sample once the features are flattened to (N*HW, C)
        offsets = (torch.arange(N, device=knn.device) * H * W).view(N, 1, 1)
        knn = (knn + offsets).view(N * H * W * K).long()  # NHWK

        # prepare CNN encoded features for RNN
        h = cnn_encoder_output  # N C H W
        # after permuting dimensions, contiguous() is generally required
        # before view() can be used
        h = h.permute(0, 2, 3, 1).contiguous()  # N H W C

        # loop over timesteps to unroll
        for i in range(self.gnn_iternum):
            # fetch features from all samples' neighbors at once;
            # repeated indices are allowed here
            h = h.view(N * (H * W), C)  # NHW C
            # N H W K*C
            neighbor_f = torch.index_select(h, 0, knn).view(N, H, W, K * C)
            neighbor_f = neighbor_f.permute(0, 3, 1, 2)  # N KC H W
            neighbor_f = self.g_rnn_conv(neighbor_f)
            neighbor_f = neighbor_f.permute(0, 2, 3, 1).contiguous()  # N H W KC
            neighbor_f = neighbor_f.view(N, H * W, K, C)
            # aggregate and iterate messages in m, keep original CNN features h for later
            m = torch.mean(neighbor_f, dim=2)
            h = h.view(N, (H * W), C)
            # concatenate current state with messages
            concat = torch.cat((h, m), 2).view(N, H, W, 2 * C)  # N H W 2C
            concat = concat.permute(0, 3, 1, 2)  # N 2C H W
            # get new features by running the convolutional q and activation
            h = self.q_rnn_conv(concat)  # N C H W
            h = h.permute(0, 2, 3, 1).contiguous()  # N H W C

        # format RNN activations back to image, concatenate original CNN embedding, return
        h = h.view(N, H, W, C).permute(0, 3, 1, 2).contiguous()  # N C H W
        output = self.output_conv(
            torch.cat((cnn_encoder_output, h), 1))  # N 2C H W
        return output
    # adapted from
    # https://discuss.pytorch.org/t/build-your-own-loss-function-in-pytorch/235/6
    # (x - y)^2 = x^2 - 2*x*y + y^2
    def __get_knn_indices(self, batch_mat):
        r = torch.bmm(batch_mat, batch_mat.permute(0, 2, 1))
        N, HW, _ = r.size()
        batch_indices = torch.zeros((N, HW, self.k)).to(self.device)
        for idx, val in enumerate(r):
            # get the diagonal elements
            diag = val.diag().unsqueeze(0)
            diag = diag.expand_as(val)
            # compute the distance matrix; clamp guards against small negative
            # values from floating-point error before the square root
            D = (diag + diag.t() - 2 * val).clamp(min=0).sqrt()
            topk, indices = torch.topk(D, k=self.k, largest=False)
            batch_indices[idx] = indices.data
        return batch_indices
if __name__ == '__main__':
    device = torch.device('cuda:0')
    gnn = GNNV2(3, k=12, gnn_iterations=3, device=device).to(device)
    cnn_encoder_output = torch.randint(0, 255, size=(5, 2048, 7, 7),
                                       dtype=torch.float32).to(device)
    original_input = torch.randint(0, 255, size=(5, 3, 224, 224),
                                   dtype=torch.float32).to(device)

    import time
    start = time.time()
    output = gnn(cnn_encoder_output, original_input)
    print(output.size())
    print(time.time() - start)
    # v1 0.054094552993774414
    # v2 0.01854419708251953
    # v3 0.02263331413269043