@nickcdryan
Created May 7, 2024 16:51
import torch
import torch.nn as nn
import torch.nn.functional as F

# Attention, FeedForward, RMSNorm, ModelArgs, and RotaryEmbedding are assumed to be
# defined elsewhere (e.g. a Llama/Mistral-style reference implementation and a
# rotary-embedding module); they are not part of this snippet.


class WeightedSkipTransformerBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.attention = Attention(args)
        self.feed_forward = FeedForward(args=args)
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.args = args
        ######### Skip weight connection: learnable scalar, initialized to zero
        self.identity_skip = nn.Parameter(torch.zeros(1))

    def forward(
        self,
        x: torch.Tensor,
        rotary_emb_fn,
    ) -> torch.Tensor:
        # Attention branch with its usual residual connection
        r = x + self.attention.forward(self.attention_norm(x), rotary_emb_fn)
        # Feed-forward branch, applied to the normalized block input
        h = self.feed_forward.forward(self.ffn_norm(x))
        ######### Apply skip weight skip connection: weighted identity path
        x = x * self.identity_skip
        out = x + h + r
        return out
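For reference, here is a minimal self-contained sketch of the same weighted-skip pattern, with plain linear layers standing in for the gist's Attention and FeedForward modules so it can run on its own; ToySkipBlock and its sizes are illustrative only, not part of the original model.

import torch
import torch.nn as nn

class ToySkipBlock(nn.Module):
    """Toy block: learnable scalar weight on an extra identity path."""
    def __init__(self, dim: int):
        super().__init__()
        self.mix = nn.Linear(dim, dim)   # stand-in for the attention branch
        self.ffn = nn.Linear(dim, dim)   # stand-in for the feed-forward branch
        self.norm = nn.LayerNorm(dim)
        # Initialized to zero, so at the start of training the extra identity
        # term contributes nothing and the block behaves like the unweighted sum.
        self.identity_skip = nn.Parameter(torch.zeros(1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r = x + self.mix(self.norm(x))           # branch with its own residual
        h = self.ffn(self.norm(x))               # branch on the normalized input
        return x * self.identity_skip + h + r    # weighted identity + both branches

block = ToySkipBlock(dim=16)
out = block(torch.randn(2, 8, 16))
print(out.shape)  # torch.Size([2, 8, 16])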
class WeightedSkipTransformer(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.vocab_size = args.vocab_size
        self.n_layers = args.n_layers
        self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
        self.layers = torch.nn.ModuleList(
            [WeightedSkipTransformerBlock(args=args) for _ in range(args.n_layers)]
        )
        self.norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.output = nn.Linear(args.dim, args.vocab_size, bias=False)
        self.rotary_emb_fn = RotaryEmbedding(dim=args.head_dim // 2)

    def forward(
        self,
        input_ids: torch.Tensor,
        labels=None,
    ) -> torch.Tensor:
        h = self.tok_embeddings(input_ids)
        for layer in self.layers:
            h = layer(h, self.rotary_emb_fn)
        # Inference: return logits over the vocabulary
        if labels is None:
            return self.output(self.norm(h)).float()
        # Training: cross_entropy expects (batch, vocab, seq) logits vs (batch, seq) labels
        loss = F.cross_entropy(
            self.output(self.norm(h)).float().transpose(-1, -2), labels
        )
        return loss
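A hedged usage sketch: it assumes a ModelArgs constructor that accepts the fields read above (dim, n_heads, head_dim, n_layers, vocab_size, norm_eps) plus whatever else Attention and FeedForward require; the values below are placeholders, not the author's configuration.

# Hypothetical configuration; field names follow what the classes above access.
args = ModelArgs(
    dim=512,
    n_heads=8,
    head_dim=64,
    n_layers=6,
    vocab_size=32000,
    norm_eps=1e-5,
)

model = WeightedSkipTransformer(args)

input_ids = torch.randint(0, args.vocab_size, (4, 128))  # (batch, seq_len)
labels = input_ids.clone()                               # toy targets for illustration

loss = model(input_ids, labels=labels)   # training path: scalar cross-entropy loss
logits = model(input_ids)                # inference path: (batch, seq_len, vocab_size) logits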