Skip to content

Instantly share code, notes, and snippets.

View berlino's full-sized avatar
🏊‍♂️
Drown

Bailin berlino

🏊‍♂️
Drown
View GitHub Profile
@berlino
berlino / ngram-block.py
Created January 28, 2024 23:28
N-gram block for ICLL (in-context language learning)
class NgramBlock(nn.Module):
    """Model block named after n-grams; only its opening lines are available.

    NOTE(review): this excerpt ends after the first layer of ``__init__``.
    Any remaining layers and the forward pass are not visible here.
    """

    # Class-level flag. Presumably tells the surrounding model that this
    # block needs the raw token ids -- TODO confirm against the caller.
    requires_input_ids = True

    def __init__(self, config, ngram):
        """Create the block's submodules.

        Author's note: parameter size 4d^2 (``d`` is presumably
        ``config.d_model`` -- TODO confirm).

        Args:
            config: Object exposing ``d_model``, used as the size handed
                to the input RMSNorm.
            ngram: Not read in the visible part of the constructor;
                presumably the n-gram order -- TODO confirm.
        """
        super().__init__()
        # Normalisation layer sized to the model width.
        self.ln_1 = RMSNorm(config.d_model, eps=1e-5)
import torch
if __name__ == "__main__":
    # Autograd smoke test: backpropagate through a matrix-vector product.
    N, d = 128, 256  # d is not used in the lines visible here
    dtype = torch.float32
    # Fall back to the CPU so the script still runs without a CUDA device.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    A = torch.randn((N, N), dtype=dtype, device=device, requires_grad=True)
    # Fill with U(0.1, 0.9) *before* enabling grad: in-place ops on a leaf
    # tensor that already requires grad raise a RuntimeError.
    p = (
        torch.empty((N,), dtype=dtype, device=device)
        .uniform_(0.1, 0.9)
        .requires_grad_(True)
    )

    o1 = A @ p
    # Populates A.grad and p.grad.
    o1.sum().backward()