Created
June 8, 2020 17:06
-
-
Save priya-dwivedi/192a8a8bd8324454bc20c85f46b5a243 to your computer and use it in GitHub Desktop.
DETR inference block
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class DETRdemo(nn.Module):
    """
    Demo DETR implementation.

    Demo implementation of DETR in minimal number of lines, with the
    following differences wrt DETR in the paper:
    * learned positional encoding (instead of sine)
    * positional encoding is passed at input (instead of attention)
    * fc bbox predictor (instead of MLP)
    The model achieves ~40 AP on COCO val5k and runs at ~28 FPS on Tesla V100.
    Only batch size 1 supported.
    """

    def __init__(self, num_classes, hidden_dim=256, nheads=8,
                 num_encoder_layers=6, num_decoder_layers=6):
        super().__init__()

        # create ResNet-50 backbone; the fully-connected classification
        # head is deleted since only the convolutional features are used
        self.backbone = resnet50()
        del self.backbone.fc

        # create conversion layer: 1x1 conv projecting the backbone's
        # 2048 output channels down to the transformer hidden dimension
        self.conv = nn.Conv2d(2048, hidden_dim, 1)

        # create a default PyTorch transformer
        self.transformer = nn.Transformer(
            hidden_dim, nheads, num_encoder_layers, num_decoder_layers)

        # prediction heads, one extra class for predicting non-empty slots
        # note that in baseline DETR linear_bbox layer is 3-layer MLP
        self.linear_class = nn.Linear(hidden_dim, num_classes + 1)
        self.linear_bbox = nn.Linear(hidden_dim, 4)

        # output positional encodings (object queries): 100 detection slots
        self.query_pos = nn.Parameter(torch.rand(100, hidden_dim))

        # spatial positional encodings, learned, half the hidden dim per axis
        # (supports feature maps up to 50x50);
        # note that in baseline DETR we use sine positional encodings
        self.row_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))
        self.col_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))

    def forward(self, inputs):
        """Detect objects in a single-image batch.

        Args:
            inputs: image tensor fed to the ResNet-50 stem
                (assumes a (1, 3, H, W) batch — only batch size 1 supported).

        Returns:
            dict with:
            * 'pred_logits': (1, 100, num_classes + 1) class logits per query
            * 'pred_boxes':  (1, 100, 4) box values squashed to [0, 1] by sigmoid
        """
        # propagate inputs through ResNet-50 up to avg-pool layer
        x = self.backbone.conv1(inputs)
        x = self.backbone.bn1(x)
        x = self.backbone.relu(x)
        x = self.backbone.maxpool(x)

        x = self.backbone.layer1(x)
        x = self.backbone.layer2(x)
        x = self.backbone.layer3(x)
        x = self.backbone.layer4(x)

        # convert from 2048 to 256 feature planes for the transformer
        h = self.conv(x)

        # construct positional encodings: concatenate a per-column and a
        # per-row learned embedding for every spatial location, then
        # flatten to (H*W, 1, hidden_dim) to match the sequence layout
        H, W = h.shape[-2:]
        pos = torch.cat([
            self.col_embed[:W].unsqueeze(0).repeat(H, 1, 1),
            self.row_embed[:H].unsqueeze(1).repeat(1, W, 1),
        ], dim=-1).flatten(0, 1).unsqueeze(1)

        # propagate through the transformer: encoder input is the flattened
        # feature map (scaled by 0.1) plus positional encodings; decoder
        # input is the learned object queries
        h = self.transformer(pos + 0.1 * h.flatten(2).permute(2, 0, 1),
                             self.query_pos.unsqueeze(1)).transpose(0, 1)

        # finally project transformer outputs to class labels and bounding boxes
        return {'pred_logits': self.linear_class(h),
                'pred_boxes': self.linear_bbox(h).sigmoid()}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment