Simple transformer encoder layer
import torch
import torch.nn as nn


class TransformerLayer(nn.Module):
    """Post-LN transformer encoder layer: self-attention and a feed-forward
    block, each followed by a residual connection and layer normalization."""

    def __init__(self, dim=496, heads=4, ffn_dim=1984):
        super().__init__()
        # Attention is a self-attention module assumed to be defined elsewhere
        self.attention = Attention(dim, heads=heads)
        self.fc = nn.Sequential(
            nn.Linear(dim, ffn_dim),
            nn.ReLU(),  # choose your favorite nonlinearity here
            nn.Linear(ffn_dim, dim),
        )
        # each sublayer gets its own LayerNorm parameters
        self.attention_norm = nn.LayerNorm(dim)
        self.fc_norm = nn.LayerNorm(dim)

    def forward(self, x):
        # x: [batch, seq_len, dim]
        residual = x
        x = self.attention(x)
        x = self.attention_norm(x + residual)

        residual = x
        x = self.fc(x)
        x = self.fc_norm(x + residual)
        return x
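
The gist does not define the Attention module it instantiates. Below is a minimal sketch of one plausible drop-in, assuming it wraps PyTorch's nn.MultiheadAttention with batch-first inputs; the name Attention and this implementation are assumptions, not the author's original module.

class Attention(nn.Module):
    # hypothetical stand-in for the Attention module assumed by the gist
    def __init__(self, dim, heads=4):
        super().__init__()
        # batch_first=True keeps tensors as [batch, seq_len, dim],
        # matching the layout TransformerLayer.forward expects
        self.mha = nn.MultiheadAttention(dim, num_heads=heads, batch_first=True)

    def forward(self, x):
        # self-attention: query, key, and value are all x
        out, _ = self.mha(x, x, x, need_weights=False)
        return out

With that in place, the layer can be exercised directly:

layer = TransformerLayer()
x = torch.randn(2, 10, 496)   # [batch, seq_len, dim]
print(layer(x).shape)         # torch.Size([2, 10, 496])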