Skip to content

Instantly share code, notes, and snippets.

View zhihou7's full-sized avatar

Zhi Hou zhihou7

View GitHub Profile
@zhihou7
zhihou7 / transformer.py
Last active August 24, 2022 05:50
BatchFormerV2 for DETR
class TransformerEncoder(nn.Module):
def __init__(self, encoder_layer, num_layers, norm=None, bf=None, bf_idx =0, insert_idx=[], use_checkpoint=False):
super().__init__()
self.layers = _get_clones(encoder_layer, num_layers)
self.num_layers = num_layers
self.norm = norm
if type(bf) != torch.nn.ModuleList and bf is not None:
self.bf = [bf]*num_layers
else: