# d4l3k/models.py

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torchvision import models
from torchvision import transforms

TRANSFORMER_DIM = 120
FPN_DIM = 128

# the five camera feeds consumed by the network
CAMERAS = (
    "main",
    "leftrepeater",
    "rightrepeater",
    "leftpillar",
    "rightpillar",
)


class CamEncoder(nn.Module):
    """Per-camera ResNet18 backbone initialized from monodepth2 encoder weights."""

    def __init__(self):
        super().__init__()
        self.model = models.resnet18(pretrained=False, progress=True)
        state_dict = torch.load(
            "../../monodepth2/models/tesla18/encoder.pth",
            map_location=torch.device("cpu"),
        )
        # the checkpoint stores the resnet under a "module.encoder." prefix;
        # strip it so the keys line up with torchvision's resnet18
        filtered_dict = {}
        PREFIX = "module.encoder."
        for name, param in state_dict.items():
            if PREFIX not in name:
                continue
            new_name = name[len(PREFIX):]
            filtered_dict[new_name] = param
        self.model.load_state_dict(filtered_dict)

    def forward(self, x):
        # adapted from torchvision.models.ResNet; returns the four
        # intermediate feature maps instead of the classification head
        x = self.model.conv1(x)
        x = self.model.bn1(x)
        x = self.model.relu(x)
        x = self.model.maxpool(x)
        f1 = self.model.layer1(x)   # stride 4
        f2 = self.model.layer2(f1)  # stride 8
        f3 = self.model.layer3(f2)  # stride 16
        f4 = self.model.layer4(f3)  # stride 32
        return f1, f2, f3, f4
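

# Example (illustrative, not from the original gist): the channel widths in
# VoxelDecoder below (320/640/1280/2560 = 5 cameras x 64/128/256/512) and the
# TransformerBlock pool dims (26x40, 13x20) imply 416x640 input frames:
#
#   enc = CamEncoder()
#   f1, f2, f3, f4 = enc(torch.rand(1, 3, 416, 640))
#   # f1: (1, 64, 104, 160)   f2: (1, 128, 52, 80)
#   # f3: (1, 256, 26, 40)    f4: (1, 512, 13, 20)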


class DepthwiseConvBlock(nn.Module):
    """
    Depthwise separable convolution.

    From https://github.com/tristandb/EfficientDet-PyTorch/blob/master/bifpn.py
    LGPL licensed
    """

    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
                 padding=0, dilation=1, freeze_bn=False):
        super().__init__()
        # depthwise: one filter per input channel (groups=in_channels)
        self.depthwise = nn.Conv2d(in_channels, in_channels, kernel_size, stride,
                                   padding, dilation, groups=in_channels, bias=False)
        # pointwise: 1x1 conv that mixes channels
        self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1,
                                   stride=1, padding=0, dilation=1, groups=1, bias=False)
        self.bn = nn.BatchNorm2d(out_channels, momentum=0.9997, eps=4e-5)
        self.act = nn.SiLU()

    def forward(self, inputs):
        x = self.depthwise(inputs)
        x = self.pointwise(x)
        x = self.bn(x)
        return self.act(x)


class BiFPNBlock(nn.Module):
    """
    Bi-directional Feature Pyramid Network.

    From https://github.com/tristandb/EfficientDet-PyTorch/blob/master/bifpn.py
    LGPL licensed
    """

    def __init__(self, feature_size=64, epsilon=0.0001):
        super().__init__()
        self.epsilon = epsilon

        self.p3_td = DepthwiseConvBlock(feature_size, feature_size)
        self.p4_td = DepthwiseConvBlock(feature_size, feature_size)
        self.p6_td = DepthwiseConvBlock(feature_size, feature_size)

        self.p4_out = DepthwiseConvBlock(feature_size, feature_size)
        self.p6_out = DepthwiseConvBlock(feature_size, feature_size)
        self.p7_out = DepthwiseConvBlock(feature_size, feature_size)

        # learned fusion weights, initialized to 1 (uniform fusion)
        self.w1 = nn.Parameter(torch.ones(2, 4))
        self.w1_relu = nn.SiLU()
        self.w2 = nn.Parameter(torch.ones(3, 4))
        self.w2_relu = nn.SiLU()

    def forward(self, inputs):
        p3_x, p4_x, p6_x, p7_x = inputs

        # fast normalized fusion weights, kept positive by an activation
        # (SiLU here despite the *_relu names)
        w1 = self.w1_relu(self.w1)
        w1 /= torch.sum(w1, dim=0) + self.epsilon
        w2 = self.w2_relu(self.w2)
        w2 /= torch.sum(w2, dim=0) + self.epsilon

        # top-down pathway (coarse to fine); the p5 level is dropped, which is
        # why weight column 1 is unused
        p7_td = p7_x
        p6_td = self.p6_td(w1[0, 0] * p6_x + w1[1, 0] * F.interpolate(p7_td, scale_factor=2))
        p4_td = self.p4_td(w1[0, 2] * p4_x + w1[1, 2] * F.interpolate(p6_td, scale_factor=2))
        p3_td = self.p3_td(w1[0, 3] * p3_x + w1[1, 3] * F.interpolate(p4_td, scale_factor=2))

        # bottom-up pathway (fine to coarse)
        p3_out = p3_td
        p4_out = self.p4_out(w2[0, 0] * p4_x + w2[1, 0] * p4_td
                             + w2[2, 0] * F.interpolate(p3_out, scale_factor=0.5))
        p6_out = self.p6_out(w2[0, 2] * p6_x + w2[1, 2] * p6_td
                             + w2[2, 2] * F.interpolate(p4_out, scale_factor=0.5))
        p7_out = self.p7_out(w2[0, 3] * p7_x + w2[1, 3] * p7_td
                             + w2[2, 3] * F.interpolate(p6_out, scale_factor=0.5))

        return [p3_out, p4_out, p6_out, p7_out]
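

# The fusion above is EfficientDet's "fast normalized fusion",
#   P_out = sum_i(w_i * P_i) / (sum_j w_j + eps),
# with SiLU in place of the paper's ReLU keeping the weights positive. Each
# block maps a 4-level pyramid to an equally shaped pyramid (FPN_DIM channels
# per level), so blocks can be chained, as VoxelDecoder does with
# nn.Sequential below.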


class TransformerBlock(nn.Module):
    def __init__(self, dims):
        super().__init__()
        self.transformer = nn.MultiheadAttention(
            embed_dim=TRANSFORMER_DIM,
            num_heads=12,
            batch_first=True,
        )
        # global max pool over the full feature map -> per-scene context vector
        self.context_encoder = nn.Sequential(
            nn.MaxPool2d(dims),
        )

        # fixed sin/cos positional encoding over the 144x80 BEV grid
        positional_encoding = torch.zeros((1, 4, 144, 80))
        z_range = torch.arange(0, 144) / 143 * 2 * math.pi
        x_range = torch.arange(0, 80) / 79 * 2 * math.pi
        positional_encoding[0, 0, :, :] = torch.sin(z_range).unsqueeze(1)
        positional_encoding[0, 1, :, :] = torch.cos(z_range).unsqueeze(1)
        positional_encoding[0, 2, :, :] = torch.sin(x_range).unsqueeze(0)
        positional_encoding[0, 3, :, :] = torch.cos(x_range).unsqueeze(0)
        self.register_buffer(
            "positional_encoding", positional_encoding, persistent=False
        )

        self.query_encoder = nn.Sequential(
            nn.Conv2d(FPN_DIM + 4, FPN_DIM, 1),
            nn.SiLU(),
            nn.Conv2d(FPN_DIM, TRANSFORMER_DIM, 1),
        )
        self.key_encoder = nn.Sequential(
            nn.Conv1d(FPN_DIM, TRANSFORMER_DIM, 1),
        )
        self.value_encoder = nn.Sequential(
            nn.Conv1d(FPN_DIM, TRANSFORMER_DIM, 1),
        )

    def forward(self, x):
        BS = len(x)

        # queries: pooled scene context broadcast over the BEV grid, plus the
        # positional encoding so each cell can attend differently
        context = self.context_encoder(x)
        context = torch.tile(context, (1, 1, 144, 80))
        pos_enc = self.positional_encoding.tile(len(context), 1, 1, 1)
        context = torch.concat((context, pos_enc), dim=1)
        query = self.query_encoder(context).permute(0, 2, 3, 1).reshape(BS, -1, TRANSFORMER_DIM)

        # keys/values: the flattened image-space feature map
        x = x.reshape(BS, FPN_DIM, -1)
        key = self.key_encoder(x).permute(0, 2, 1)
        value = self.value_encoder(x).permute(0, 2, 1)

        bev, _weights = self.transformer(
            query,
            key,
            value,
            need_weights=False,
        )
        bev = bev.reshape(BS, 144, 80, TRANSFORMER_DIM)
        bev = bev.permute(0, 3, 1, 2)
        return bev
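

# In effect this is a cross-attention projection from image space to
# bird's-eye view (an illustrative summary, not from the original gist): each
# of the 144x80 BEV cells forms a query from the max-pooled scene context plus
# its sin/cos grid position, attends over all flattened H*W pixels of the
# input feature level, and the outputs are reshaped into a
# (BS, TRANSFORMER_DIM, 144, 80) BEV feature map.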


class VoxelDecoder(nn.Module):
    def __init__(self):
        super().__init__()
        # squeeze the concatenated per-camera pyramids (5 cameras x resnet18
        # channel counts) down to FPN_DIM channels per level
        self.f4_encoder = nn.Sequential(
            nn.Conv2d(2560, 512, 1),
            nn.SiLU(),
            nn.Conv2d(512, FPN_DIM, 1),
            nn.SiLU(),
        )
        self.f3_encoder = nn.Sequential(
            nn.Conv2d(1280, 256, 1),
            nn.SiLU(),
            nn.Conv2d(256, FPN_DIM, 1),
            nn.SiLU(),
        )
        self.f2_encoder = nn.Sequential(
            nn.Conv2d(640, FPN_DIM, 1),
            nn.SiLU(),
        )
        self.f1_encoder = nn.Sequential(
            nn.Conv2d(320, FPN_DIM, 1),
            nn.SiLU(),
        )
        bifpns = [BiFPNBlock(FPN_DIM) for _ in range(2)]
        self.bifpn = nn.Sequential(*bifpns)
        # self.transformer1 = TransformerBlock((104, 160))
        # self.transformer2 = TransformerBlock((52, 80))
        self.transformer3 = TransformerBlock((26, 40))
        self.transformer4 = TransformerBlock((13, 20))
        self.decoder = nn.Sequential(
            nn.Conv2d(TRANSFORMER_DIM * 2, TRANSFORMER_DIM, 1),
            nn.SiLU(),
            nn.Conv2d(TRANSFORMER_DIM, 24, 1),
            nn.SiLU(),
            nn.Conv2d(24, 12, 1),
        )

    def forward(self, features):
        # concatenate each pyramid level across the five cameras
        f1_stack = []
        f2_stack = []
        f3_stack = []
        f4_stack = []
        for cam in CAMERAS:
            f1, f2, f3, f4 = features[cam]
            f1_stack.append(f1)
            f2_stack.append(f2)
            f3_stack.append(f3)
            f4_stack.append(f4)
        x1 = self.f1_encoder(torch.concat(f1_stack, dim=1))
        x2 = self.f2_encoder(torch.concat(f2_stack, dim=1))
        x3 = self.f3_encoder(torch.concat(f3_stack, dim=1))
        x4 = self.f4_encoder(torch.concat(f4_stack, dim=1))
        x1, x2, x3, x4 = self.bifpn((x1, x2, x3, x4))
        # bev1 = self.transformer1(x1)
        # bev2 = self.transformer2(x2)
        bev3 = self.transformer3(x3)
        bev4 = self.transformer4(x4)
        bev = torch.concat((bev3, bev4), dim=1)
        voxels = self.decoder(bev)
        return voxels
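

# Shape walk-through (illustrative, assuming 416x640 frames): the five camera
# pyramids concatenate to 320/640/1280/2560 channels, are squeezed to FPN_DIM,
# refined by two chained BiFPN blocks, and the two coarsest levels (26x40 and
# 13x20) are each attended into a (BS, 120, 144, 80) BEV map; the concatenated
# (BS, 240, 144, 80) tensor decodes to (BS, 12, 144, 80) voxel logits.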


class VoxelNet(nn.Module):
    def __init__(self, freeze_cams=True):
        super().__init__()
        self.models = nn.ModuleDict({cam: CamEncoder() for cam in CAMERAS})
        self.decoder = VoxelDecoder()
        # optionally freeze the per-camera encoders at construction
        self.freeze_cams(freeze_cams)

    def freeze_cams(self, freeze=True):
        # freeze (or unfreeze) the per-camera encoders
        for model in self.models.values():
            for param in model.parameters():
                param.requires_grad = not freeze

    def forward(self, imgs):
        features = {name: self.models[name](img) for name, img in imgs.items()}
        return self.decoder(features)


def normalize(tensor):
    # raw HWC numpy frame -> normalized CHW float tensor
    tensor = torch.from_numpy(tensor.astype(np.float32))
    # convert to CHW
    tensor = tensor.permute(2, 0, 1)
    # per-channel mean/std computed on the raw (unscaled) pixel values
    return transforms.functional.normalize(
        tensor,
        (18069, 24765, 18272),
        (1739, 2487, 1844),
        inplace=True,
    )
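

if __name__ == "__main__":
    # Smoke test (a sketch, not part of the original gist). Assumes the
    # monodepth2 encoder checkpoint referenced in CamEncoder is available and
    # that frames are 416x640, which matches the channel/pool dims above.
    net = VoxelNet()
    imgs = {cam: torch.rand(2, 3, 416, 640) for cam in CAMERAS}
    voxels = net(imgs)
    print(voxels.shape)  # expected: torch.Size([2, 12, 144, 80])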