Skip to content

Instantly share code, notes, and snippets.

@thomwolf
Last active August 9, 2019 09:21
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save thomwolf/f357aab1fac8f05288cadef5e3145c4f to your computer and use it in GitHub Desktop.
GPT-2 main model class
class GPT2Model(nn.Module):
    """GPT-2 main model skeleton.

    Builds token + position embeddings, an embedding dropout, a stack of
    ``config.n_layer`` transformer blocks, and a final layer norm.
    ``init_weights`` and ``forward`` are placeholders in this snippet
    (both raise ``NotImplementedError``).

    NOTE(review): ``Block`` and ``LayerNorm`` are defined elsewhere in the
    project — their exact semantics are assumed from their names.
    """

    def __init__(self, config):
        # BUG FIX: nn.Module.__init__ takes no arguments; the original
        # super(...).__init__(config) raised TypeError on construction.
        super(GPT2Model, self).__init__()
        # Token embedding table: one n_embd-dim vector per vocab entry.
        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        # Learned absolute position embeddings, one per position up to n_positions.
        self.wpe = nn.Embedding(config.n_positions, config.n_embd)
        # Dropout applied to the summed embeddings (rate from config).
        self.drop = nn.Dropout(config.embd_pdrop)
        # Stack of identical transformer blocks.
        self.h = nn.ModuleList([Block(config.n_ctx, config, scale=True) for _ in range(config.n_layer)])
        # Final layer norm over the last block's hidden states.
        self.ln_f = LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        # Recursively initialize all submodules.
        self.apply(self.init_weights)

    def init_weights(self, module):
        # BUG FIX: nn.Module.apply invokes the callback as fn(module), so the
        # method must accept the visited module; the original zero-argument
        # signature raised TypeError inside apply() before this placeholder
        # NotImplementedError could ever fire.
        raise NotImplementedError

    def forward(self, input_ids, past=None):
        # Placeholder: the forward pass (embedding lookup, block stack,
        # final layer norm) is not implemented in this snippet.
        raise NotImplementedError
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment