# d4l3k/models.py (secret gist, created Jan 6, 2022)
import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models, transforms

TRANSFORMER_DIM = 120
FPN_DIM = 128

CAMERAS = (
    "main",
    "leftrepeater",
    "rightrepeater",
    "leftpillar",
    "rightpillar",
)
class CamEncoder(nn.Module):
    """Per-camera image encoder: a ResNet-18 backbone initialized from a
    monodepth2 depth-encoder checkpoint."""

    def __init__(self):
        super().__init__()
        self.model = models.resnet18(pretrained=False, progress=True)
        state_dict = torch.load(
            "../../monodepth2/models/tesla18/encoder.pth",
            map_location=torch.device("cpu"),
        )
        # The checkpoint was saved from a DataParallel-wrapped encoder, so
        # strip the "module.encoder." prefix and drop any unrelated keys.
        filtered_dict = {}
        PREFIX = "module.encoder."
        for name, param in state_dict.items():
            if not name.startswith(PREFIX):
                continue
            new_name = name[len(PREFIX):]
            filtered_dict[new_name] = param
        self.model.load_state_dict(filtered_dict)

    def forward(self, x):
        # Adapted from torchvision.models.ResNet: run the stem, then return
        # the four intermediate feature maps (strides 4, 8, 16, 32) instead
        # of the pooled classification output.
        x = self.model.conv1(x)
        x = self.model.bn1(x)
        x = self.model.relu(x)
        x = self.model.maxpool(x)
        f1 = self.model.layer1(x)
        f2 = self.model.layer2(f1)
        f3 = self.model.layer3(f2)
        f4 = self.model.layer4(f3)
        return f1, f2, f3, f4
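
# Illustrative sketch, not part of the original gist: the expected feature
# shapes out of CamEncoder. The (104, 160) ... (13, 20) grids used by the
# TransformerBlocks below imply 416x640 camera frames, so that resolution is
# an assumption here. Running this requires the monodepth2 checkpoint
# referenced in CamEncoder.__init__.
def _demo_cam_encoder():
    enc = CamEncoder()
    f1, f2, f3, f4 = enc(torch.zeros(1, 3, 416, 640))
    assert f1.shape == (1, 64, 104, 160)   # stride 4
    assert f2.shape == (1, 128, 52, 80)    # stride 8
    assert f3.shape == (1, 256, 26, 40)    # stride 16
    assert f4.shape == (1, 512, 13, 20)    # stride 32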
class DepthwiseConvBlock(nn.Module):
    """
    Depthwise separable convolution.

    From https://github.com/tristandb/EfficientDet-PyTorch/blob/master/bifpn.py
    LGPL licensed
    """

    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
                 padding=0, dilation=1, freeze_bn=False):
        super().__init__()
        # freeze_bn is kept for API compatibility with the upstream code but
        # is unused here.
        self.depthwise = nn.Conv2d(in_channels, in_channels, kernel_size, stride,
                                   padding, dilation, groups=in_channels, bias=False)
        self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1,
                                   stride=1, padding=0, dilation=1, groups=1, bias=False)
        self.bn = nn.BatchNorm2d(out_channels, momentum=0.9997, eps=4e-5)
        self.act = nn.SiLU()

    def forward(self, inputs):
        x = self.depthwise(inputs)
        x = self.pointwise(x)
        x = self.bn(x)
        return self.act(x)
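
# Illustrative sketch, not part of the original gist: for a k x k kernel, a
# depthwise separable conv uses C_in*k*k + C_in*C_out weights instead of
# C_in*C_out*k*k, roughly a k*k-fold saving when C_out is large.
def _demo_depthwise_params():
    block = DepthwiseConvBlock(128, 128, kernel_size=3, padding=1)
    dense = nn.Conv2d(128, 128, kernel_size=3, padding=1, bias=False)
    separable = (block.depthwise.weight.numel()
                 + block.pointwise.weight.numel())
    print(separable, dense.weight.numel())  # 17536 vs 147456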
class BiFPNBlock(nn.Module):
    """
    Bi-directional Feature Pyramid Network

    From https://github.com/tristandb/EfficientDet-PyTorch/blob/master/bifpn.py
    LGPL licensed
    """

    def __init__(self, feature_size=64, epsilon=0.0001):
        super().__init__()
        self.epsilon = epsilon

        self.p3_td = DepthwiseConvBlock(feature_size, feature_size)
        self.p4_td = DepthwiseConvBlock(feature_size, feature_size)
        self.p6_td = DepthwiseConvBlock(feature_size, feature_size)

        self.p4_out = DepthwiseConvBlock(feature_size, feature_size)
        self.p6_out = DepthwiseConvBlock(feature_size, feature_size)
        self.p7_out = DepthwiseConvBlock(feature_size, feature_size)

        # Fast normalized fusion weights, initialized to ones as in the
        # EfficientDet paper (resolves the original "TODO: Init weights";
        # torch.Tensor(2, 4) left the values uninitialized).
        self.w1 = nn.Parameter(torch.ones(2, 4))
        self.w1_relu = nn.SiLU()
        self.w2 = nn.Parameter(torch.ones(3, 4))
        self.w2_relu = nn.SiLU()

    def forward(self, inputs):
        p3_x, p4_x, p6_x, p7_x = inputs

        # Normalize the fusion weights. Column index 1 is unused: this
        # network fuses four levels (p3, p4, p6, p7) but keeps the upstream
        # five-level indexing.
        w1 = self.w1_relu(self.w1)
        w1 = w1 / (torch.sum(w1, dim=0) + self.epsilon)
        w2 = self.w2_relu(self.w2)
        w2 = w2 / (torch.sum(w2, dim=0) + self.epsilon)

        # Calculate top-down pathway.
        p7_td = p7_x
        p6_td = self.p6_td(w1[0, 0] * p6_x + w1[1, 0] * F.interpolate(p7_td, scale_factor=2))
        p4_td = self.p4_td(w1[0, 2] * p4_x + w1[1, 2] * F.interpolate(p6_td, scale_factor=2))
        p3_td = self.p3_td(w1[0, 3] * p3_x + w1[1, 3] * F.interpolate(p4_td, scale_factor=2))

        # Calculate bottom-up pathway.
        p3_out = p3_td
        p4_out = self.p4_out(w2[0, 0] * p4_x + w2[1, 0] * p4_td
                             + w2[2, 0] * F.interpolate(p3_out, scale_factor=0.5))
        p6_out = self.p6_out(w2[0, 2] * p6_x + w2[1, 2] * p6_td
                             + w2[2, 2] * F.interpolate(p4_out, scale_factor=0.5))
        p7_out = self.p7_out(w2[0, 3] * p7_x + w2[1, 3] * p7_td
                             + w2[2, 3] * F.interpolate(p6_out, scale_factor=0.5))
        return [p3_out, p4_out, p6_out, p7_out]
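
# Illustrative sketch, not part of the original gist: a BiFPNBlock fuses
# information across pyramid levels while preserving each level's shape,
# which is why blocks can be chained with nn.Sequential in VoxelDecoder
# below. Level sizes match the assumed 416x640 input.
def _demo_bifpn():
    block = BiFPNBlock(FPN_DIM)
    levels = [
        torch.zeros(1, FPN_DIM, 104, 160),
        torch.zeros(1, FPN_DIM, 52, 80),
        torch.zeros(1, FPN_DIM, 26, 40),
        torch.zeros(1, FPN_DIM, 13, 20),
    ]
    outs = block(levels)
    assert [o.shape for o in outs] == [lvl.shape for lvl in levels]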
class TransformerBlock(nn.Module):
    """Cross-attention from a fixed 144x80 BEV (bird's-eye-view) query grid
    to the flattened image-space features of one pyramid level."""

    def __init__(self, dims):
        super().__init__()
        self.transformer = nn.MultiheadAttention(
            embed_dim=TRANSFORMER_DIM,
            num_heads=12,
            batch_first=True,
        )
        # Pool the whole feature map down to a single context vector.
        self.context_encoder = nn.Sequential(
            nn.MaxPool2d(dims),
        )
        # Fixed sinusoidal positional encoding over the 144x80 BEV grid.
        positional_encoding = torch.zeros((1, 4, 144, 80))
        z_range = torch.arange(0, 144) / 143 * 2 * math.pi
        x_range = torch.arange(0, 80) / 79 * 2 * math.pi
        positional_encoding[0, 0, :, :] = torch.sin(z_range).unsqueeze(1)
        positional_encoding[0, 1, :, :] = torch.cos(z_range).unsqueeze(1)
        positional_encoding[0, 2, :, :] = torch.sin(x_range).unsqueeze(0)
        positional_encoding[0, 3, :, :] = torch.cos(x_range).unsqueeze(0)
        self.register_buffer(
            "positional_encoding", positional_encoding, persistent=False
        )
        self.query_encoder = nn.Sequential(
            nn.Conv2d(FPN_DIM + 4, FPN_DIM, 1),
            nn.SiLU(),
            nn.Conv2d(FPN_DIM, TRANSFORMER_DIM, 1),
        )
        self.key_encoder = nn.Sequential(
            nn.Conv1d(FPN_DIM, TRANSFORMER_DIM, 1),
        )
        self.value_encoder = nn.Sequential(
            nn.Conv1d(FPN_DIM, TRANSFORMER_DIM, 1),
        )

    def forward(self, x):
        BS = len(x)
        # Queries: the pooled context vector broadcast across the BEV grid,
        # concatenated with the positional encoding.
        context = self.context_encoder(x)
        context = torch.tile(context, (1, 1, 144, 80))
        pos_enc = self.positional_encoding.tile(len(context), 1, 1, 1)
        context = torch.concat((context, pos_enc), dim=1)
        query = self.query_encoder(context).permute(0, 2, 3, 1).reshape(BS, -1, TRANSFORMER_DIM)
        # Keys/values: the image features flattened to a sequence.
        x = x.reshape(BS, FPN_DIM, -1)
        key = self.key_encoder(x).permute(0, 2, 1)
        value = self.value_encoder(x).permute(0, 2, 1)
        # need_weights=False returns (output, None), so drop the weights.
        bev, _ = self.transformer(
            query,
            key,
            value,
            need_weights=False,
        )
        bev = bev.reshape(BS, 144, 80, TRANSFORMER_DIM)
        bev = bev.permute(0, 3, 1, 2)
        return bev
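
# Illustrative sketch, not part of the original gist: one TransformerBlock
# turns an image-space feature map into a fixed-size BEV feature map. Here
# 144 * 80 = 11520 BEV queries attend over 26 * 40 = 1040 image cells.
def _demo_transformer_block():
    block = TransformerBlock((26, 40))
    bev = block(torch.zeros(1, FPN_DIM, 26, 40))
    assert bev.shape == (1, TRANSFORMER_DIM, 144, 80)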
class VoxelDecoder(nn.Module):
    """Fuses the per-camera feature pyramids and decodes them into a BEV
    voxel grid. The input channel counts are the per-level ResNet-18
    channels (64/128/256/512) concatenated across the 5 cameras."""

    def __init__(self):
        super().__init__()
        self.f4_encoder = nn.Sequential(
            nn.Conv2d(2560, 512, 1),  # 512 channels * 5 cameras
            nn.SiLU(),
            nn.Conv2d(512, FPN_DIM, 1),
            nn.SiLU(),
        )
        self.f3_encoder = nn.Sequential(
            nn.Conv2d(1280, 256, 1),  # 256 channels * 5 cameras
            nn.SiLU(),
            nn.Conv2d(256, FPN_DIM, 1),
            nn.SiLU(),
        )
        self.f2_encoder = nn.Sequential(
            nn.Conv2d(640, FPN_DIM, 1),  # 128 channels * 5 cameras
            nn.SiLU(),
        )
        self.f1_encoder = nn.Sequential(
            nn.Conv2d(320, FPN_DIM, 1),  # 64 channels * 5 cameras
            nn.SiLU(),
        )
        bifpns = [BiFPNBlock(FPN_DIM) for _ in range(2)]
        self.bifpn = nn.Sequential(*bifpns)
        # Only the two coarsest levels feed the BEV transformers; the finer
        # levels are left disabled.
        # self.transformer1 = TransformerBlock((104, 160))
        # self.transformer2 = TransformerBlock((52, 80))
        self.transformer3 = TransformerBlock((26, 40))
        self.transformer4 = TransformerBlock((13, 20))
        self.decoder = nn.Sequential(
            nn.Conv2d(TRANSFORMER_DIM * 2, TRANSFORMER_DIM, 1),
            nn.SiLU(),
            nn.Conv2d(TRANSFORMER_DIM, 24, 1),
            nn.SiLU(),
            nn.Conv2d(24, 12, 1),
        )

    def forward(self, features):
        # Concatenate each pyramid level across cameras along the channel dim.
        f1_stack = []
        f2_stack = []
        f3_stack = []
        f4_stack = []
        for cam in CAMERAS:
            f1, f2, f3, f4 = features[cam]
            f1_stack.append(f1)
            f2_stack.append(f2)
            f3_stack.append(f3)
            f4_stack.append(f4)
        x1 = self.f1_encoder(torch.concat(f1_stack, dim=1))
        x2 = self.f2_encoder(torch.concat(f2_stack, dim=1))
        x3 = self.f3_encoder(torch.concat(f3_stack, dim=1))
        x4 = self.f4_encoder(torch.concat(f4_stack, dim=1))
        x1, x2, x3, x4 = self.bifpn((x1, x2, x3, x4))
        # bev1 = self.transformer1(x1)
        # bev2 = self.transformer2(x2)
        bev3 = self.transformer3(x3)
        bev4 = self.transformer4(x4)
        bev = torch.concat((bev3, bev4), dim=1)
        voxels = self.decoder(bev)
        return voxels
class VoxelNet(nn.Module):
    def __init__(self, freeze_cams=True):
        super().__init__()
        self.models = nn.ModuleDict({cam: CamEncoder() for cam in CAMERAS})
        self.decoder = VoxelDecoder()
        # Apply the constructor flag (the original accepted freeze_cams but
        # never used it).
        self.freeze_cams(freeze_cams)

    def freeze_cams(self, freeze=True):
        # Freeze (or unfreeze) the per-camera encoders.
        for model in self.models.values():
            for param in model.parameters():
                param.requires_grad = not freeze

    def forward(self, imgs):
        features = {name: self.models[name](img) for name, img in imgs.items()}
        return self.decoder(features)
def normalize(tensor):
    """Converts an HWC numpy image to a normalized CHW float tensor. The
    mean/std values are per-channel statistics of the raw (apparently
    16-bit) camera data."""
    tensor = torch.from_numpy(tensor.astype(np.float32))
    # convert to CHW
    tensor = tensor.permute(2, 0, 1)
    return transforms.functional.normalize(
        tensor,
        (18069, 24765, 18272),
        (1739, 2487, 1844),
        inplace=True,
    )
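
# Illustrative end-to-end sketch, not part of the original gist: 416x640
# frames are assumed (implied by the TransformerBlock dims above), and the
# uint16 raw-image dtype is a guess based on the normalize() statistics.
# Running this requires the monodepth2 checkpoint referenced in CamEncoder.
def _demo_voxelnet():
    net = VoxelNet(freeze_cams=True)
    imgs = {
        cam: normalize(np.zeros((416, 640, 3), dtype=np.uint16)).unsqueeze(0)
        for cam in CAMERAS
    }
    voxels = net(imgs)
    assert voxels.shape == (1, 12, 144, 80)  # 12-channel BEV voxel grid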