Skip to content

Instantly share code, notes, and snippets.

@berlino
Created January 28, 2024 23:28
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save berlino/f5226e051b571322705365cf39399031 to your computer and use it in GitHub Desktop.
Save berlino/f5226e051b571322705365cf39399031 to your computer and use it in GitHub Desktop.
N-gram block for ICLL
class NgramBlock(nn.Module):
    """Pre-norm residual block whose token-mixing layer is an ``Ngram`` module.

    Layout (pre-norm, two residual branches):
        x = x + dropout(Ngram(RMSNorm(x), input_ids))
        x = x + dropout(MLP(RMSNorm(x)))

    The original author notes a parameter budget of roughly 4*d^2 for the
    block; accordingly the MLP hidden width is d_model (1x), not the usual 4x.
    """

    # The Ngram mixer consumes raw token ids, so callers/trainers must
    # forward `input_ids` into this block (hence this class-level flag).
    requires_input_ids = True

    def __init__(self, config, ngram):
        """Build the block.

        Args:
            config: model config; reads ``d_model`` and ``resid_pdrop``.
                (Other fields may be read inside Ngram — can't tell from here.)
            ngram: ngram order/spec passed through to the ``Ngram`` mixer.
        """
        super().__init__()
        self.ln_1 = RMSNorm(config.d_model, eps=1e-5)
        self.attn = Ngram(config, ngram)
        self.ln_2 = RMSNorm(config.d_model, eps=1e-5)
        # NOTE: hidden width deliberately equals d_model (keeps the 4d^2 budget).
        hidden_dim = config.d_model
        self.mlp = nn.Sequential(
            nn.Linear(config.d_model, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, config.d_model),
        )
        self.resid_dropout = nn.Dropout(config.resid_pdrop)

    def forward(self, x, input_ids):
        """Apply ngram mixing, then the MLP, each as a dropped-out residual.

        Args:
            x: hidden states (presumably (batch, seq, d_model) — TODO confirm).
            input_ids: token ids consumed by the Ngram mixer.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Branch 1: ngram token mixing over the normalized input.
        x = x + self.resid_dropout(self.attn(self.ln_1(x), input_ids))
        # Branch 2: position-wise feed-forward over the normalized input.
        x = x + self.resid_dropout(self.mlp(self.ln_2(x)))
        return x
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment