BatchFormerV2 for DETR
# Imports and the _get_clones helper are included here so the snippet is self-contained;
# in the original DETR codebase they live in models/transformer.py.
import copy
from typing import Optional

import torch
import torch.nn as nn
from torch import Tensor


def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])


class TransformerEncoder(nn.Module):
    def __init__(self, encoder_layer, num_layers, norm=None, bf=None, bf_idx=0,
                 insert_idx=[], use_checkpoint=False):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm
        # If a single BatchFormerV2 module is passed, share it across all layers.
        # Note: a plain Python list does not register bf as a submodule, so its
        # parameters must be owned (and optimized) by the caller.
        if bf is not None and not isinstance(bf, nn.ModuleList):
            self.bf = [bf] * num_layers
        else:
            self.bf = bf
        self.insert_idx = insert_idx
        self.bf_idx = bf_idx
        self.use_checkpoint = use_checkpoint

    def forward(self, src,
                mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None):
        output = src
        for i, layer in enumerate(self.layers):
            output = layer(output, src_mask=mask,
                           src_key_padding_mask=src_key_padding_mask, pos=pos)
            # Apply BatchFormerV2 (batch attention) only during training and only
            # after the encoder layers listed in insert_idx.
            if i in self.insert_idx and self.bf is not None and self.bf_idx == 3 and self.training:
                L, B, C = output.shape
                old_output = output
                if i != self.insert_idx[0]:
                    # The batch was already doubled at the first insertion point:
                    # keep the plain half, apply batch attention to the other half.
                    old_output = output[:, :B // 2, :]
                    output = output[:, B // 2:, :]
                    # old_output = output[:, :len(output)//2, :]
                    # output = output[:, len(output)//2:, :]
                # the original batches
                # Attend across the batch dimension: (L, B, C) -> (B, L, C) -> bf -> back.
                output = self.bf[i](torch.transpose(output, 1, 0))
                output = torch.transpose(output, 1, 0)
                # Concatenate the plain branch and the batch-attended branch along the batch dim.
                output = torch.cat([old_output, output], dim=1)
                if i == self.insert_idx[0]:
                    # Duplicate positional encodings and padding masks to match the doubled batch.
                    pos = torch.cat([pos, pos], dim=1)
                    src_key_padding_mask = torch.cat([src_key_padding_mask, src_key_padding_mask], dim=0)
        if self.norm is not None:
            output = self.norm(output)
        return output
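Below is a minimal usage sketch, not part of the gist. PosEncoderLayer is a hypothetical stand-in for DETR's own TransformerEncoderLayer (which accepts a pos argument), the shared BatchFormerV2 block is assumed here to be a plain nn.TransformerEncoderLayer attending across the batch dimension, and all hyperparameters and shapes are illustrative.

# Illustrative sketch only (assumptions as noted above); reuses the imports from the snippet.
class PosEncoderLayer(nn.Module):
    # Minimal encoder layer with DETR-style keyword arguments (src_mask, pos, ...).
    def __init__(self, d_model, nhead):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, src, src_mask=None, src_key_padding_mask=None, pos=None):
        q = k = src if pos is None else src + pos
        attn_out = self.self_attn(q, k, value=src, attn_mask=src_mask,
                                  key_padding_mask=src_key_padding_mask)[0]
        return self.norm(src + attn_out)


d_model, nhead = 256, 8
encoder_layer = PosEncoderLayer(d_model, nhead)
# Shared batch-attention block: it receives tensors of shape (B, L, C), so with the
# default (seq, batch, feature) convention its attention runs over the batch dimension.
batch_former = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward=1024)

encoder = TransformerEncoder(encoder_layer, num_layers=6, norm=nn.LayerNorm(d_model),
                             bf=batch_former, bf_idx=3, insert_idx=[0, 1])
encoder.train()  # batch attention is only applied in training mode

L, B = 100, 4                                    # sequence length, batch size
src = torch.randn(L, B, d_model)
pos = torch.randn(L, B, d_model)
padding_mask = torch.zeros(B, L, dtype=torch.bool)

out = encoder(src, src_key_padding_mask=padding_mask, pos=pos)
print(out.shape)  # torch.Size([100, 8, 256]): original batch plus its batch-attended copy

Because the shared block is kept in a plain Python list, its parameters are not returned by encoder.parameters(); in a full model it would be registered on the parent module and optimized there. At inference time (encoder.eval()) the batch is never doubled, so BatchFormerV2 adds no test-time cost.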