import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models, transforms

TRANSFORMER_DIM = 120
FPN_DIM = 128

CAMERAS = (
    "main",
    "leftrepeater",
    "rightrepeater",
    "leftpillar",
    "rightpillar",
)
class CamEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = models.resnet18(pretrained=False, progress=True)
        state_dict = torch.load(
            "../../monodepth2/models/tesla18/encoder.pth",
            map_location=torch.device("cpu"),
        )
        # the monodepth2 checkpoint stores the encoder under a
        # DataParallel-style "module.encoder." prefix; strip it and drop
        # everything else
        filtered_dict = {}
        PREFIX = "module.encoder."
        for name, param in state_dict.items():
            if PREFIX not in name:
                continue
            new_name = name[len(PREFIX):]
            filtered_dict[new_name] = param
        self.model.load_state_dict(filtered_dict)

    def forward(self, x):
        # adapted from torchvision.models.ResNet; returns the four
        # intermediate feature maps instead of the classification head
        x = self.model.conv1(x)
        x = self.model.bn1(x)
        x = self.model.relu(x)
        x = self.model.maxpool(x)
        f1 = self.model.layer1(x)
        f2 = self.model.layer2(f1)
        f3 = self.model.layer3(f2)
        f4 = self.model.layer4(f3)
        return f1, f2, f3, f4
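
# Hedged sketch: a quick shape check for CamEncoder. The 416x640 input size is
# an assumption inferred from the (26, 40) and (13, 20) pooling dims used in
# VoxelDecoder below; the gist never states the camera resolution. Running
# this also requires the monodepth2 checkpoint above to exist on disk.
def _check_cam_encoder_shapes():
    enc = CamEncoder()
    f1, f2, f3, f4 = enc(torch.rand(1, 3, 416, 640))
    # resnet18 downsamples by 4/8/16/32 with 64/128/256/512 channels
    assert f1.shape == (1, 64, 104, 160)
    assert f2.shape == (1, 128, 52, 80)
    assert f3.shape == (1, 256, 26, 40)
    assert f4.shape == (1, 512, 13, 20)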
class DepthwiseConvBlock(nn.Module):
    """
    Depthwise separable convolution.
    From https://github.com/tristandb/EfficientDet-PyTorch/blob/master/bifpn.py
    LGPL licensed
    """
    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
                 padding=0, dilation=1, freeze_bn=False):
        super().__init__()
        self.depthwise = nn.Conv2d(in_channels, in_channels, kernel_size, stride,
                                   padding, dilation, groups=in_channels, bias=False)
        self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1,
                                   stride=1, padding=0, dilation=1, groups=1, bias=False)
        self.bn = nn.BatchNorm2d(out_channels, momentum=0.9997, eps=4e-5)
        self.act = nn.SiLU()

    def forward(self, inputs):
        x = self.depthwise(inputs)
        x = self.pointwise(x)
        x = self.bn(x)
        return self.act(x)
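
# Why depthwise separable: factoring a conv into a per-channel (depthwise)
# pass plus a 1x1 (pointwise) channel mix is much cheaper. For 64 -> 64
# channels with a 3x3 kernel, a full conv needs 64*64*3*3 = 36864 weights,
# while depthwise (64*3*3 = 576) + pointwise (64*64 = 4096) needs 4672,
# roughly 8x fewer. Note the BiFPN below instantiates this block with the
# default kernel_size=1, where the depthwise pass degenerates to a
# per-channel scale.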
class BiFPNBlock(nn.Module):
    """
    Bi-directional Feature Pyramid Network
    From https://github.com/tristandb/EfficientDet-PyTorch/blob/master/bifpn.py
    LGPL licensed
    """
    def __init__(self, feature_size=64, epsilon=0.0001):
        super().__init__()
        self.epsilon = epsilon

        self.p3_td = DepthwiseConvBlock(feature_size, feature_size)
        self.p4_td = DepthwiseConvBlock(feature_size, feature_size)
        self.p6_td = DepthwiseConvBlock(feature_size, feature_size)

        self.p4_out = DepthwiseConvBlock(feature_size, feature_size)
        self.p6_out = DepthwiseConvBlock(feature_size, feature_size)
        self.p7_out = DepthwiseConvBlock(feature_size, feature_size)

        # fusion weights, initialized to ones so every input starts equally
        # weighted (the original used torch.Tensor(), which leaves them
        # uninitialized). Column 1 of each is unused -- a leftover from the
        # five-level (p3..p7) BiFPN this was adapted from.
        self.w1 = nn.Parameter(torch.ones(2, 4))
        self.w1_relu = nn.SiLU()
        self.w2 = nn.Parameter(torch.ones(3, 4))
        self.w2_relu = nn.SiLU()

    def forward(self, inputs):
        p3_x, p4_x, p6_x, p7_x = inputs

        # fast normalized fusion: positive weights normalized to sum to ~1
        w1 = self.w1_relu(self.w1)
        w1 /= torch.sum(w1, dim=0) + self.epsilon
        w2 = self.w2_relu(self.w2)
        w2 /= torch.sum(w2, dim=0) + self.epsilon

        # top-down pathway (coarse to fine)
        p7_td = p7_x
        p6_td = self.p6_td(w1[0, 0] * p6_x + w1[1, 0] * F.interpolate(p7_td, scale_factor=2))
        p4_td = self.p4_td(w1[0, 2] * p4_x + w1[1, 2] * F.interpolate(p6_td, scale_factor=2))
        p3_td = self.p3_td(w1[0, 3] * p3_x + w1[1, 3] * F.interpolate(p4_td, scale_factor=2))

        # bottom-up pathway (fine to coarse)
        p3_out = p3_td
        p4_out = self.p4_out(w2[0, 0] * p4_x + w2[1, 0] * p4_td
                             + w2[2, 0] * F.interpolate(p3_out, scale_factor=0.5))
        p6_out = self.p6_out(w2[0, 2] * p6_x + w2[1, 2] * p6_td
                             + w2[2, 2] * F.interpolate(p4_out, scale_factor=0.5))
        p7_out = self.p7_out(w2[0, 3] * p7_x + w2[1, 3] * p7_td
                             + w2[2, 3] * F.interpolate(p6_out, scale_factor=0.5))

        return [p3_out, p4_out, p6_out, p7_out]
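
# Hedged sketch: the four inputs must share a channel count (feature_size) and
# sit exactly one octave apart, since every fusion up/downsamples by 2. The
# sizes below match the x1..x4 maps produced in VoxelDecoder, assuming the
# 416x640 camera resolution guessed above:
def _check_bifpn_shapes():
    block = BiFPNBlock(feature_size=FPN_DIM)
    p3 = torch.rand(1, FPN_DIM, 104, 160)
    p4 = torch.rand(1, FPN_DIM, 52, 80)
    p6 = torch.rand(1, FPN_DIM, 26, 40)
    p7 = torch.rand(1, FPN_DIM, 13, 20)
    outs = block((p3, p4, p6, p7))
    for x, y in zip((p3, p4, p6, p7), outs):
        assert x.shape == y.shape  # BiFPN preserves every level's shape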
class TransformerBlock(nn.Module):
    def __init__(self, dims):
        super().__init__()
        self.transformer = nn.MultiheadAttention(
            embed_dim=TRANSFORMER_DIM,
            num_heads=12,
            batch_first=True,
        )
        # global context: pool the whole feature map down to a single vector
        self.context_encoder = nn.Sequential(
            nn.MaxPool2d(dims),
        )
        # fixed sin/cos positional encoding over the 144x80 BEV query grid
        positional_encoding = torch.zeros((1, 4, 144, 80))
        z_range = torch.arange(0, 144) / 143 * 2 * math.pi
        x_range = torch.arange(0, 80) / 79 * 2 * math.pi
        positional_encoding[0, 0, :, :] = torch.sin(z_range).unsqueeze(1)
        positional_encoding[0, 1, :, :] = torch.cos(z_range).unsqueeze(1)
        positional_encoding[0, 2, :, :] = torch.sin(x_range).unsqueeze(0)
        positional_encoding[0, 3, :, :] = torch.cos(x_range).unsqueeze(0)
        self.register_buffer(
            "positional_encoding", positional_encoding, persistent=False
        )
        self.query_encoder = nn.Sequential(
            nn.Conv2d(FPN_DIM + 4, FPN_DIM, 1),
            nn.SiLU(),
            nn.Conv2d(FPN_DIM, TRANSFORMER_DIM, 1),
        )
        self.key_encoder = nn.Sequential(
            nn.Conv1d(FPN_DIM, TRANSFORMER_DIM, 1),
        )
        self.value_encoder = nn.Sequential(
            nn.Conv1d(FPN_DIM, TRANSFORMER_DIM, 1),
        )

    def forward(self, x):
        BS = len(x)
        # queries: pooled global context + positional encoding, one per BEV cell
        context = self.context_encoder(x)
        context = torch.tile(context, (1, 1, 144, 80))
        pos_enc = self.positional_encoding.tile(len(context), 1, 1, 1)
        context = torch.concat((context, pos_enc), dim=1)
        query = self.query_encoder(context).permute(0, 2, 3, 1).reshape(BS, -1, TRANSFORMER_DIM)
        # keys/values: the flattened image-space feature map
        x = x.reshape(BS, FPN_DIM, -1)
        key = self.key_encoder(x).permute(0, 2, 1)
        value = self.value_encoder(x).permute(0, 2, 1)
        # cross-attention from the BEV grid into image space
        bev, _ = self.transformer(
            query,
            key,
            value,
            need_weights=False,
        )
        bev = bev.reshape(BS, 144, 80, TRANSFORMER_DIM)
        bev = bev.permute(0, 3, 1, 2)
        return bev
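
# Hedged sketch of the TransformerBlock shape flow, assuming a (1, 128, 26, 40)
# input (the x3 level): 144*80 = 11520 BEV queries attend over 26*40 = 1040
# image-space keys/values, giving a (1, 120, 144, 80) BEV map.
def _check_transformer_shapes():
    block = TransformerBlock((26, 40))
    bev = block(torch.rand(1, FPN_DIM, 26, 40))
    assert bev.shape == (1, TRANSFORMER_DIM, 144, 80)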
class VoxelDecoder(nn.Module):
    def __init__(self):
        super().__init__()
        # each encoder takes the five cameras' features concatenated on the
        # channel dim (5 x 512 = 2560, 5 x 256 = 1280, 5 x 128 = 640, 5 x 64 = 320)
        self.f4_encoder = nn.Sequential(
            nn.Conv2d(2560, 512, 1),
            nn.SiLU(),
            nn.Conv2d(512, FPN_DIM, 1),
            nn.SiLU(),
        )
        self.f3_encoder = nn.Sequential(
            nn.Conv2d(1280, 256, 1),
            nn.SiLU(),
            nn.Conv2d(256, FPN_DIM, 1),
            nn.SiLU(),
        )
        self.f2_encoder = nn.Sequential(
            nn.Conv2d(640, FPN_DIM, 1),
            nn.SiLU(),
        )
        self.f1_encoder = nn.Sequential(
            nn.Conv2d(320, FPN_DIM, 1),
            nn.SiLU(),
        )
        bifpns = [BiFPNBlock(FPN_DIM) for i in range(2)]
        self.bifpn = nn.Sequential(*bifpns)
        #self.transformer1 = TransformerBlock((104, 160))
        #self.transformer2 = TransformerBlock((52, 80))
        self.transformer3 = TransformerBlock((26, 40))
        self.transformer4 = TransformerBlock((13, 20))
        self.decoder = nn.Sequential(
            nn.Conv2d(TRANSFORMER_DIM * 2, TRANSFORMER_DIM, 1),
            nn.SiLU(),
            nn.Conv2d(TRANSFORMER_DIM, 24, 1),
            nn.SiLU(),
            nn.Conv2d(24, 12, 1),
        )

    def forward(self, features):
        f1_stack = []
        f2_stack = []
        f3_stack = []
        f4_stack = []
        for cam in CAMERAS:
            f1, f2, f3, f4 = features[cam]
            f1_stack.append(f1)
            f2_stack.append(f2)
            f3_stack.append(f3)
            f4_stack.append(f4)
        x1 = self.f1_encoder(torch.concat(f1_stack, dim=1))
        x2 = self.f2_encoder(torch.concat(f2_stack, dim=1))
        x3 = self.f3_encoder(torch.concat(f3_stack, dim=1))
        x4 = self.f4_encoder(torch.concat(f4_stack, dim=1))
        x1, x2, x3, x4 = self.bifpn((x1, x2, x3, x4))
        # only the two coarsest levels are lifted to BEV
        #bev1 = self.transformer1(x1)
        #bev2 = self.transformer2(x2)
        bev3 = self.transformer3(x3)
        bev4 = self.transformer4(x4)
        bev = torch.concat((bev3, bev4), dim=1)
        voxels = self.decoder(bev)
        return voxels
class VoxelNet(nn.Module):
    def __init__(self, freeze_cams=True):
        super().__init__()
        self.models = nn.ModuleDict({cam: CamEncoder() for cam in CAMERAS})
        self.decoder = VoxelDecoder()
        # the original accepted this argument but never used it; apply it here
        self.freeze_cams(freeze_cams)

    def freeze_cams(self, freeze=True):
        # freeze (or unfreeze) the per-camera encoders
        for model in self.models.values():
            for param in model.parameters():
                param.requires_grad = not freeze

    def forward(self, imgs):
        features = {name: self.models[name](img) for name, img in imgs.items()}
        return self.decoder(features)
def normalize(tensor):
    """Convert an HWC numpy image to a normalized CHW float tensor."""
    tensor = torch.from_numpy(tensor.astype(np.float32))
    # convert to CHW
    tensor = tensor.permute(2, 0, 1)
    # per-channel mean/std; the magnitudes suggest raw 16-bit sensor data
    return transforms.functional.normalize(
        tensor,
        (18069, 24765, 18272),
        (1739, 2487, 1844),
        inplace=True,
    )
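
# Hedged end-to-end sketch: build the network and run one forward pass on
# random per-camera images. The 416x640 resolution is an assumption inferred
# from the pooling dims above, and CamEncoder needs the monodepth2 checkpoint
# on disk, so treat this as illustrative rather than a guaranteed demo.
if __name__ == "__main__":
    net = VoxelNet(freeze_cams=True)
    imgs = {cam: torch.rand(1, 3, 416, 640) for cam in CAMERAS}
    with torch.no_grad():
        voxels = net(imgs)
    # 12 output channels over the 144x80 BEV grid
    print(voxels.shape)  # torch.Size([1, 12, 144, 80])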