@scturtle
Created May 20, 2024 09:43
llama3 in numpy
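A single-file, NumPy-only forward pass of a small Llama-style transformer (RMSNorm, RoPE, SwiGLU feed-forward, per-layer KV cache), decoded greedily on the stories15M checkpoint.
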
import numpy as np
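
# Hyperparameters of the stories15M checkpoint: 288-dim embeddings, 6 layers, 6 heads.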
class ModelArgs:
    dim = 288
    n_layers = 6
    n_heads = 6
    norm_eps = 1e-6

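# Precompute RoPE cos/sin tables for every position up to seq_len and every channel pair.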
def build_cos_sin_cache(head_dim, seq_len, base=10000):
    theta = 1. / (base ** (np.arange(0, head_dim, 2, dtype=np.float32) / head_dim))
    seq_idx = np.arange(seq_len, dtype=np.float32)
    idx_theta = np.outer(seq_idx, theta)
    return np.cos(idx_theta), np.sin(idx_theta)

cos_cached, sin_cached = build_cos_sin_cache(ModelArgs.dim // ModelArgs.n_heads, seq_len=256)
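
# Rotary position embedding: rotate each (even, odd) channel pair by an angle
# determined by the token's absolute position (start_pos + offset within x).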
def rope(x, start_pos):
    seq_len = x.shape[1]
    r = np.zeros_like(x)
    cos = cos_cached[start_pos: start_pos + seq_len][None, :, None, :]
    sin = sin_cached[start_pos: start_pos + seq_len][None, :, None, :]
    r[:, :, :, 0::2] = x[:, :, :, 0::2] * cos - x[:, :, :, 1::2] * sin
    r[:, :, :, 1::2] = x[:, :, :, 1::2] * cos + x[:, :, :, 0::2] * sin
    return r

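# Numerically stable softmax over the last axis.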
def softmax(x):
    e = np.exp(x - np.max(x, axis=-1, keepdims=True))
    return e / np.sum(e, axis=-1, keepdims=True)

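# SiLU activation: x * sigmoid(x).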
def silu(x):
    return x * (1 / (1 + np.exp(-x)))

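# SwiGLU feed-forward: gate with silu(x @ W_gate), multiply by x @ W_up, project back down.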
def ffn(x, up_wgt, gate_wgt, down_wgt):
    return (silu(x @ gate_wgt) * (x @ up_wgt)) @ down_wgt

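# RMSNorm without the learned scale; callers multiply by the layer's weight themselves.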
def rmsnorm(x, eps=ModelArgs.norm_eps):
    return x / np.sqrt(np.mean(x ** 2, axis=-1, keepdims=True) + eps)

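# Multi-head self-attention with a per-layer KV cache. Cached keys/values are prepended,
# heads are repeated to cover grouped-query layouts (a no-op here, where q and k have the
# same number of heads), and a causal mask is applied only during the multi-token prefill.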
def attn(x, start_pos, q_wgt, k_wgt, v_wgt, o_wgt, cache):
    q = x @ q_wgt
    k = x @ k_wgt
    v = x @ v_wgt
    B, L, d = x.shape
    q = q.reshape((B, L, ModelArgs.n_heads, d // ModelArgs.n_heads))
    k = k.reshape((B, L, ModelArgs.n_heads, d // ModelArgs.n_heads))
    v = v.reshape((B, L, ModelArgs.n_heads, d // ModelArgs.n_heads))
    q = rope(q, start_pos)
    k = rope(k, start_pos)
    if cache:
        k_cache, v_cache = cache
        k = np.concatenate([k_cache, k], axis=1)
        v = np.concatenate([v_cache, v], axis=1)
    cache[:] = [k, v]
    n_rep = q.shape[-2] // k.shape[-2]
    k = np.repeat(k, n_rep, axis=-2)
    v = np.repeat(v, n_rep, axis=-2)
    x = np.einsum('...qhd,...khd->...hqk', q, k)
    if L > 1:
        mask = (1 - np.tri(x.shape[-1], dtype=x.dtype)) * -1e10
    else:
        mask = 0
    x = softmax(x * q.shape[-1] ** -0.5 + mask)
    x = np.einsum('...hqk,...khd->...qhd', x, v)
    x = x.reshape(x.shape[:-2] + (-1,))
    return x @ o_wgt

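# One transformer block: pre-norm attention and pre-norm feed-forward, each wrapped in a residual.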
def block(x, start_pos, layer_id, weights, cache):
    rms_wgt_in = weights[f"model.layers.{layer_id}.input_layernorm.weight"]
    q_wgt = weights[f"model.layers.{layer_id}.self_attn.q_proj.weight"]
    k_wgt = weights[f"model.layers.{layer_id}.self_attn.k_proj.weight"]
    v_wgt = weights[f"model.layers.{layer_id}.self_attn.v_proj.weight"]
    o_wgt = weights[f"model.layers.{layer_id}.self_attn.o_proj.weight"]
    rms_wgt_out = weights[f"model.layers.{layer_id}.post_attention_layernorm.weight"]
    up_wgt = weights[f"model.layers.{layer_id}.mlp.up_proj.weight"]
    gate_wgt = weights[f"model.layers.{layer_id}.mlp.gate_proj.weight"]
    down_wgt = weights[f"model.layers.{layer_id}.mlp.down_proj.weight"]
    norm_x = rmsnorm(x) * rms_wgt_in
    x += attn(norm_x, start_pos, q_wgt, k_wgt, v_wgt, o_wgt, cache)
    norm_x = rmsnorm(x) * rms_wgt_out
    x += ffn(norm_x, up_wgt, gate_wgt, down_wgt)
    return x

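# Full forward pass: embed tokens, run every block, apply the final norm, project to vocabulary logits.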
def llama3(x, start_pos, weights, caches):
    x = weights["model.embed_tokens.weight"][x]
    for i in range(ModelArgs.n_layers):
        x = block(x, start_pos, layer_id=i, weights=weights, cache=caches[i])
    x = rmsnorm(x) * weights["model.norm.weight"]
    return x @ weights["lm_head.weight"]

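# Greedy decoding: the first pass prefills the whole prompt, then one token is fed per step
# using the per-layer KV caches.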
def main():
    from tokenizer import Tokenizer
    tokenizer = Tokenizer("./tokenizer.model.np")
    weights = dict(np.load("./stories15M.model.npz"))
    # Projection and LM-head weights are stored as (out, in); transpose once so x @ W works.
    for k in weights:
        if k.endswith('proj.weight') or k == "lm_head.weight":
            weights[k] = weights[k].T
    prompt = "I have a dream"
    print(prompt, end="", flush=True)
    x = np.array([tokenizer.encode(prompt)])
    caches = [[] for _ in range(ModelArgs.n_layers)]
    for start_pos in range(x.shape[1], 56):
        # Prefill: run the whole prompt from position 0. Decode: the single input token
        # sits at position len(cache), which is one less than the loop counter.
        start_pos = 0 if not caches[0] else start_pos - 1
        logits = llama3(x, start_pos, weights, caches)
        x = np.argmax(logits[:, -1, :], axis=-1, keepdims=True)
        print(tokenizer.decode(x[0]), end="", flush=True)

if __name__ == "__main__":
    main()