@thomwolf
Created August 9, 2019 09:25
Main forward pass for GPT-2
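```python
import torch

def forward(self, input_ids):
    # Position indices 0..seq_len-1, created on the same device as the inputs
    position_ids = torch.arange(0, input_ids.size(-1), dtype=torch.long, device=input_ids.device)
    # Broadcast the positions to every sequence in the batch: (batch, seq_len)
    position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
    # Sum token embeddings (wte) and learned position embeddings (wpe)
    hidden_states = self.wte(input_ids) + self.wpe(position_ids)
    hidden_states = self.drop(hidden_states)
    # Run the stack of Transformer blocks in order
    for block in self.h:
        hidden_states = block(hidden_states)
    # Final layer norm after the last block
    hidden_states = self.ln_f(hidden_states)
    return hidden_states
```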
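To see this forward pass run end to end, here is a minimal sketch, assuming PyTorch is installed. `TinyGPT2Trunk` and its small default sizes are hypothetical, chosen only so that `forward` has the attributes it reads (`wte`, `wpe`, `drop`, `h`, `ln_f`). The blocks in `self.h` are placeholders (LayerNorm, Linear, GELU), not GPT-2's actual masked self-attention + MLP blocks, so the output illustrates shapes and data flow only, not real GPT-2 activations.

```python
import torch
import torch.nn as nn


class TinyGPT2Trunk(nn.Module):
    """Hypothetical toy module exposing the attribute names used by forward().
    The blocks are placeholders, NOT real GPT-2 Transformer blocks."""

    def __init__(self, vocab_size=100, n_positions=16, n_embd=8, n_layer=2, dropout=0.1):
        super().__init__()
        self.wte = nn.Embedding(vocab_size, n_embd)   # token embeddings
        self.wpe = nn.Embedding(n_positions, n_embd)  # learned position embeddings
        self.drop = nn.Dropout(dropout)
        # Placeholder blocks (assumption for illustration), standing in for
        # GPT-2's attention + MLP blocks; each maps (B, T, C) -> (B, T, C).
        self.h = nn.ModuleList(
            [nn.Sequential(nn.LayerNorm(n_embd), nn.Linear(n_embd, n_embd), nn.GELU())
             for _ in range(n_layer)]
        )
        self.ln_f = nn.LayerNorm(n_embd)

    def forward(self, input_ids):
        # Same computation as the gist's forward()
        position_ids = torch.arange(0, input_ids.size(-1), dtype=torch.long, device=input_ids.device)
        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
        hidden_states = self.wte(input_ids) + self.wpe(position_ids)
        hidden_states = self.drop(hidden_states)
        for block in self.h:
            hidden_states = block(hidden_states)
        return self.ln_f(hidden_states)


torch.manual_seed(0)
model = TinyGPT2Trunk()
model.eval()  # disable dropout for a deterministic pass
input_ids = torch.randint(0, 100, (2, 5))  # 2 sequences of 5 token ids
out = model(input_ids)
print(out.shape)  # torch.Size([2, 5, 8])
```

The position ids are built once with shape `(seq_len,)` and expanded to `(batch, seq_len)` with `expand_as`, which returns a view rather than copying memory. Because `wpe` is a lookup table with `n_positions` rows, inputs longer than that (1024 tokens for the released GPT-2 models) fail in the position-embedding lookup.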