@zcbenz
Last active July 15, 2024 05:28
mlx-nan-bug
{
  "model_type": "llama",
  "hidden_size": 288,
  "intermediate_size": 768,
  "num_hidden_layers": 6,
  "num_attention_heads": 6,
  "num_key_value_heads": 6,
  "rms_norm_eps": 1e-05,
  "vocab_size": 48588
}
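For reference, this config describes a tiny Llama-style model: hidden_size 288 split across 6 attention heads gives a per-head dimension of 288 / 6 = 48, and num_key_value_heads equals num_attention_heads, so there is no grouped-query attention. The training script below reads this file as config.json; a minimal check of the derived size (illustrative only, not part of the gist):

import json

with open("config.json") as f:
    config = json.load(f)
head_dim = config["hidden_size"] // config["num_attention_heads"]  # 288 // 6 == 48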
import inspect
import mlx.core as mx
import mlx.nn as nn
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union


@dataclass
class BaseModelArgs:
    @classmethod
    def from_dict(cls, params):
        return cls(
            **{
                k: v
                for k, v in params.items()
                if k in inspect.signature(cls).parameters
            }
        )
@dataclass
class ModelArgs(BaseModelArgs):
    model_type: str
    hidden_size: int
    num_hidden_layers: int
    intermediate_size: int
    num_attention_heads: int
    rms_norm_eps: float
    vocab_size: int
    num_key_value_heads: Optional[int] = None
    rope_theta: float = 10000
    rope_traditional: bool = False
    rope_scaling: Optional[Dict[str, Union[float, str]]] = None

    def __post_init__(self):
        if self.num_key_value_heads is None:
            self.num_key_value_heads = self.num_attention_heads

        if self.rope_scaling:
            required_keys = {"factor", "type"}
            if not all(key in self.rope_scaling for key in required_keys):
                raise ValueError(f"rope_scaling must contain keys {required_keys}")

            if self.rope_scaling["type"] != "linear":
                raise ValueError("rope_scaling 'type' currently only supports 'linear'")
class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        dim = args.hidden_size
        self.n_heads = n_heads = args.num_attention_heads
        self.n_kv_heads = n_kv_heads = args.num_key_value_heads

        head_dim = args.hidden_size // n_heads
        self.scale = head_dim**-0.5

        self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
        self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
        self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
        self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)

        rope_scale = (
            1 / args.rope_scaling["factor"]
            if args.rope_scaling is not None and args.rope_scaling["type"] == "linear"
            else 1
        )
        self.rope = nn.RoPE(
            head_dim,
            traditional=args.rope_traditional,
            base=args.rope_theta,
            scale=rope_scale,
        )

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> mx.array:
        B, L, D = x.shape

        queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)

        # Prepare the queries, keys and values for the attention computation
        queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
        keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
        values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)

        if cache is not None:
            key_cache, value_cache = cache
            queries = self.rope(queries, offset=key_cache.shape[2])
            keys = self.rope(keys, offset=key_cache.shape[2])
            keys = mx.concatenate([key_cache, keys], axis=2)
            values = mx.concatenate([value_cache, values], axis=2)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)

        output = mx.fast.scaled_dot_product_attention(
            queries, keys, values, scale=self.scale, mask=mask
        )
        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.o_proj(output), (keys, values)
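As a shape reference (illustrative sketch, not part of the gist), with the config above each projection maps the 288-dim hidden state into 6 heads of 48 dims, and the returned (keys, values) pair is what later calls concatenate onto via the cache:

attn = Attention(args)                 # args as built in the sketch above
x = mx.random.normal((1, 16, 288))     # (batch, sequence, hidden_size)
out, (k, v) = attn(x)
# out: (1, 16, 288); k and v: (1, 6, 16, 48) == (batch, heads, sequence, head_dim)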
class MLP(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
        self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
        self.up_proj = nn.Linear(dim, hidden_dim, bias=False)

    def __call__(self, x) -> mx.array:
        return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
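This is the Llama-style gated (SwiGLU) feed-forward, down_proj(silu(gate_proj(x)) * up_proj(x)), expanding 288 to 768 and back. A quick shape check (illustrative only, not part of the gist):

mlp = MLP(288, 768)
y = mlp(mx.random.normal((1, 16, 288)))
# y: (1, 16, 288); the intermediate gated activations are (1, 16, 768)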
class TransformerBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.num_attention_heads = args.num_attention_heads
        self.hidden_size = args.hidden_size
        self.self_attn = Attention(args)
        self.mlp = MLP(args.hidden_size, args.intermediate_size)
        self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
        self.post_attention_layernorm = nn.RMSNorm(
            args.hidden_size, eps=args.rms_norm_eps
        )
        self.args = args

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> mx.array:
        r, cache = self.self_attn(self.input_layernorm(x), mask, cache)
        h = x + r
        r = self.mlp(self.post_attention_layernorm(h))
        out = h + r
        return out, cache
class LlamaModel(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.vocab_size = args.vocab_size
        self.num_hidden_layers = args.num_hidden_layers
        assert self.vocab_size > 0
        self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
        self.layers = [
            TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
        ]
        self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        h = self.embed_tokens(inputs)

        mask = None
        if h.shape[1] > 1:
            mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
            mask = mask.astype(h.dtype)

        if cache is None:
            cache = [None] * len(self.layers)

        for e, layer in enumerate(self.layers):
            h, cache[e] = layer(h, mask, cache[e])

        return self.norm(h), cache
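Note that the additive causal mask is only built for inputs longer than one token; during cached single-token decoding the new token may attend to everything already in the cache. A small sketch of the mask (illustrative, not part of the gist):

mask = nn.MultiHeadAttention.create_additive_causal_mask(4)
# A 4x4 additive mask: 0 on and below the diagonal and a large negative value
# above it, so once added to the attention scores each position can only
# attend to itself and earlier positions.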
class Model(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.model_type = args.model_type
        self.model = LlamaModel(args)
        self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        out, cache = self.model(inputs, cache)
        return self.lm_head(out), cache

    def sanitize(self, weights):
        # Remove unused precomputed rotary freqs
        return {
            k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
        }

    @property
    def layers(self):
        return self.model.layers
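This ends the model definition, which the training script below imports as the model module. An illustrative forward pass with the tiny config (not part of the gist), to show the expected output shapes:

m = Model(args)                                 # args as built in the sketch above
tokens = mx.random.randint(0, 48588, (1, 8))    # (batch, sequence) of token ids
logits, cache = m(tokens)
# logits: (1, 8, 48588); cache: one (keys, values) pair per layer,
# each of shape (1, 6, 8, 48)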
import json
import time
import numpy as np
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from transformers import AutoTokenizer
from datasets import load_dataset

from model import Model, ModelArgs


class Dataset:
    """
    Light-weight wrapper to hold a dataset.
    """

    def __init__(self, data, text_key="text"):
        self._text_key = text_key
        self._data = data

    def __getitem__(self, idx):
        return self._data[idx][self._text_key]

    def __len__(self):
        if self._data is None:
            return 0
        return len(self._data)
def iterate_batches(dataset, tokenizer, batch_size, context_size):
    x_batch = []
    y_batch = []
    for text in dataset:
        tokens = tokenizer.encode(text)
        for i in range(0, len(tokens) - 1, context_size):
            length = min(context_size, len(tokens) - i - 1)
            # If the window is shorter than context_size, pad it with the eos token.
            paddings = []
            if length < context_size:
                paddings = [tokenizer.eos_token_id] * (context_size - length)
            x_batch.append(tokens[i : i + length] + paddings)
            y_batch.append(tokens[i + 1 : i + 1 + length] + paddings)
            while len(x_batch) >= batch_size:
                yield x_batch[:batch_size], y_batch[:batch_size]
                x_batch, y_batch = x_batch[batch_size:], y_batch[batch_size:]
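To make the windowing concrete, a small worked example (illustrative, not part of the gist):

# With context_size = 4 and a 10-token document [t0, ..., t9], the generator
# produces
#   x windows: [t0, t1, t2, t3], [t4, t5, t6, t7], [t8, eos, eos, eos]
#   y windows: [t1, t2, t3, t4], [t5, t6, t7, t8], [t9, eos, eos, eos]
# The last window is padded with eos_token_id on both inputs and targets, and
# nothing masks those padded positions out of the loss later on.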
batch_size = 32
context_size = 256
max_iterations = 1000

with open('config.json', 'r') as f:
    config = json.load(f)

model = Model(ModelArgs.from_dict(config))
tokenizer = AutoTokenizer.from_pretrained('mlx-community/Meta-Llama-3-8B-Instruct-8bit')
dataset = Dataset(load_dataset('Chat-Error/tinystories-gpt4', split='train'))

def loss_fn(model, x, y):
    logits, cache = model(x)
    losses = nn.losses.cross_entropy(logits, y)
    return mx.mean(losses)
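Shape note (assuming the default reduction="none" of nn.losses.cross_entropy, which returns unreduced per-token losses):

# With batch_size = 32 and context_size = 256:
#   logits: (32, 256, 48588), y: (32, 256), losses: (32, 256)
# mx.mean then averages over every position, including the eos padding
# appended by iterate_batches.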
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
optimizer = optim.AdamW(1e-3)

for it, (x, y) in zip(range(1, max_iterations + 1),
                      iterate_batches(dataset, tokenizer, batch_size, context_size)):
    x = mx.array(x)
    y = mx.array(y)
    loss, grads = loss_and_grad_fn(model, x, y)
    optimizer.update(model, grads)
    mx.eval(model.state, optimizer.state)
    print('Iter', it, 'Loss', loss.item())
    if mx.isnan(loss):
        exit(1)
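A minimal diagnostic sketch, not part of the gist, for narrowing down where a NaN first appears; it assumes mlx.utils.tree_flatten, and the helper name any_nan is chosen here purely for illustration:

from mlx.utils import tree_flatten

def any_nan(tree):
    # True if any array in a parameter/gradient tree contains a NaN.
    return any(mx.any(mx.isnan(v)).item() for _, v in tree_flatten(tree))

# Inside the training loop, right after loss_and_grad_fn:
#   if any_nan(grads):
#       print('NaN appeared in the gradients at iteration', it)
#   if any_nan(model.parameters()):
#       print('NaN appeared in the weights at iteration', it)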