
@matthewchung74
Created February 5, 2021 23:43
import torch.nn as nn

class Block(nn.Module):
    """Transformer encoder block: pre-norm multi-head self-attention and a pre-norm MLP,
    each wrapped in a residual connection. Assumes Attention and Mlp are defined elsewhere."""
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None,
                 drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)  # use the norm_layer argument rather than hard-coding nn.LayerNorm
        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
                              attn_drop=attn_drop, proj_drop=drop)
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)  # MLP expansion (default 4x the embedding dim)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
        # Note: drop_path (stochastic depth) is accepted but not applied in this simplified block.

    def forward(self, x):
        # Pre-norm residual connections: attention sub-layer, then MLP sub-layer.
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x
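For context, here is a minimal usage sketch. It assumes the Attention and Mlp modules from the rest of this gist (or timm's vision_transformer implementations) are importable, and the dim=192, num_heads=3 values are illustrative only; the block preserves the (batch, tokens, dim) shape of its input.

import torch

# Hypothetical example values; any dim divisible by num_heads works.
block = Block(dim=192, num_heads=3, mlp_ratio=4., qkv_bias=True)
tokens = torch.randn(2, 197, 192)   # batch of 2, 196 patch tokens + 1 class token, embed dim 192
out = block(tokens)
print(out.shape)                    # torch.Size([2, 197, 192]) -- shape is unchanged by the block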