@thomwolf
Created August 8, 2019 18:44
GPT-2 PyTorch block module
import torch.nn as nn

# LayerNorm, Attention, and MLP are defined elsewhere in pytorch-transformers'
# modeling_gpt2.py (there, LayerNorm is an alias for nn.LayerNorm).
class Block(nn.Module):
    def __init__(self, n_ctx, config, scale=False):
        super(Block, self).__init__()
        nx = config.n_embd
        self.ln_1 = LayerNorm(nx, eps=config.layer_norm_epsilon)
        self.attn = Attention(nx, n_ctx, config, scale)
        self.ln_2 = LayerNorm(nx, eps=config.layer_norm_epsilon)
        self.mlp = MLP(4 * nx, config)  # feed-forward inner dim is 4x the embedding size

    def forward(self, x):
        # Pre-LayerNorm residual block: normalize, transform, then add back.
        a = self.attn(self.ln_1(x))   # self-attention sub-layer
        x = x + a                     # residual connection
        m = self.mlp(self.ln_2(x))    # feed-forward sub-layer
        x = x + m                     # residual connection
        return x
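
The gist omits the Attention, MLP, and LayerNorm modules the block depends on. Below is a minimal sketch that exercises the block end to end, using hypothetical stand-ins (single-head causal attention and a GELU feed-forward), not the pytorch-transformers implementations; the Config class and all dimensions are illustrative.

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

LayerNorm = nn.LayerNorm  # the library aliases nn.LayerNorm the same way

class Attention(nn.Module):
    """Stand-in: single-head causal self-attention, tensor in -> tensor out."""
    def __init__(self, nx, n_ctx, config, scale=False):
        super().__init__()
        self.c_attn = nn.Linear(nx, 3 * nx)  # fused q, k, v projection
        self.c_proj = nn.Linear(nx, nx)
        self.scale = scale
        # Lower-triangular mask so each position attends only to the past.
        self.register_buffer("bias", torch.tril(torch.ones(n_ctx, n_ctx)))

    def forward(self, x):
        q, k, v = self.c_attn(x).chunk(3, dim=-1)
        w = q @ k.transpose(-2, -1)
        if self.scale:
            w = w / math.sqrt(v.size(-1))
        t = x.size(1)
        w = w.masked_fill(self.bias[:t, :t] == 0, float("-inf"))
        return self.c_proj(F.softmax(w, dim=-1) @ v)

class MLP(nn.Module):
    """Stand-in: two-layer feed-forward with GELU, as in GPT-2."""
    def __init__(self, n_state, config):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, n_state)
        self.c_proj = nn.Linear(n_state, config.n_embd)

    def forward(self, x):
        return self.c_proj(F.gelu(self.c_fc(x)))

class Config:
    """Hypothetical minimal config exposing the two fields Block reads."""
    n_embd = 768
    layer_norm_epsilon = 1e-5

block = Block(n_ctx=1024, config=Config(), scale=True)
h = torch.randn(2, 16, 768)       # (batch, sequence, embedding)
assert block(h).shape == h.shape  # residual connections preserve the shape

Note the pre-LayerNorm placement: each sub-layer normalizes its input before transforming it, and the residual path carries x through unnormalized, which is the layout GPT-2 uses.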