Skip to content

Instantly share code, notes, and snippets.

@Guitaricet
Last active April 30, 2020 19:46
Show Gist options
  • Save Guitaricet/bac30c97f3acac12a5bf404d4d26c1b7 to your computer and use it in GitHub Desktop.
Lite transformer encoder layer
class LiteTransformerLayer(nn.Module):
    """Lite Transformer encoder layer (Wu et al., ICLR 2020, "LSRA").

    Splits the feature dimension in half: one half is processed by
    self-attention (global context), the other by a lightweight
    convolution (local context). The branch outputs are concatenated
    back to the full width, then a position-wise feed-forward block is
    applied. Both stages use residual connections with layer norm.

    Args:
        dim: model width; must be even so it can be split in half.
        heads: number of attention heads for the attention branch.
        kernel_size: kernel width for the convolution branch.
    """

    def __init__(self, dim=496, heads=4, kernel_size=4):
        super().__init__()
        assert dim % 2 == 0, "dim must be even to split between the two branches"
        # Global branch: attention over half the channels.
        self.attention = Attention(dim // 2, heads=heads)
        # Local branch: lightweight conv over the other half (or DynamicConv).
        self.cnn = LightweightConv(dim // 2, kernel=kernel_size)
        self.fc = nn.Sequential(
            nn.Linear(dim, dim),
            nn.ReLU(),  # choose your favorite nonlinearity here
            nn.Linear(dim, dim),
        )
        # NOTE(review): one LayerNorm instance is reused for both residual
        # connections, tying their affine parameters. Transformer layers
        # conventionally use two separate norms — confirm this is intended.
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        # assumes x is (batch, seq, dim) with features last — TODO confirm
        # against the conventions of Attention / LightweightConv.
        dim = x.shape[-1]
        residual = x
        # BUG FIX: the original sliced x[:dim // 2], which cuts the FIRST
        # (batch) axis, while torch.cat(dim=-1) concatenates along the
        # feature axis. Split the feature (last) axis instead.
        x1 = self.attention(x[..., :dim // 2])
        x2 = self.cnn(x[..., dim // 2:])
        x = torch.cat([x1, x2], dim=-1)
        x = self.norm(x + residual)

        residual = x
        x = self.fc(x)
        x = self.norm(x + residual)
        return x
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment