@honglu2875
Last active October 31, 2023 09:52
Aria scripts
"""Contains generation/sampling code"""
# This file contains code from https://github.com/facebookresearch/llama which
# is available under the following licence:
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the GNU
# General Public License version 3.
import torch
from typing import List
from aria.model import TransformerLM
from aria.tokenizer import Tokenizer
# TODO:
# - Enable sampling sequences longer than max_seq_len by truncating
# Some good settings:
# temp=0.85, top_p=0.9, cfg_gamma=1.4
@torch.autocast(device_type="cuda", dtype=torch.float16)
def interpolate_sample(
model: TransformerLM,
tokenizer: Tokenizer,
prompts: List[list],
alternative: List[list],
max_seq_len: int,
max_gen_len: int,
force_end=False,
temperature: float = 0.85,
top_p: float = 0.9,
cfg_gamma: float | None = 1.2,
alpha: float | None = 0.3,
):
"""Performs greedy (top_p) autoregressive sampling on a batch of prompts.
Args:
model (TransformerLM): Model to sample from.
tokenizer (Tokenizer): Tokenizer corresponding to model.
prompts (List[list]): A list of prompts to sample as a batch.
max_seq_len (int): Maximum sequence length supported by the model.
max_gen_len (int): Maximum desired sequence length of the samples.
temperature (float, optional): Sampling temperature. Defaults to 0.75.
top_p (float, optional): Parameter for top-p sampling. Defaults to 0.95.
Returns:
List[list]: The list of samples, decoded by the tokenizer.
"""
assert tokenizer.return_tensors is True, "tokenizer must return tensors."
model.eval()
pad_id = tokenizer.pad_id
eos_id = tokenizer.tok_to_id[tokenizer.eos_tok]
bsz = len(prompts)
min_prompt_size = min([len(t) for t in prompts])
max_prompt_size = max([len(t) for t in prompts])
total_len = min(max_seq_len, max_gen_len + max_prompt_size)
if cfg_gamma:
assert (
min_prompt_size == max_prompt_size
), "CFG not supported with varying prompts"
if force_end:
assert (
total_len - max_prompt_size > 130
), "prompt too long to use force_end=True"
print(
f"Using hyperparams: temp={temperature}, top_p={top_p}, gamma={cfg_gamma}, gen_len={max_gen_len}"
)
tokens = torch.full((bsz, total_len), pad_id).cuda()
alt_tokens = torch.full((bsz, total_len), pad_id).cuda()
alt_len = min(total_len, min(len(a) for a in alternative))
for idx, (unencoded_seq, alt_seq) in enumerate(zip(prompts, alternative)):
tokens[idx, : len(unencoded_seq)] = tokenizer.encode(unencoded_seq)
alt_tokens[idx, : alt_len] = tokenizer.encode(alt_seq)[:alt_len]
dim_tok_inserted = [False for _ in range(bsz)]
input_text_mask = tokens != pad_id
start_pos = min_prompt_size
past_kv = None
alt_kv = None
_use_cache = True
with torch.inference_mode():
for cur_pos in range(start_pos, total_len):
token = tokens[:, :start_pos] if cur_pos == start_pos else tokens[:, cur_pos-1:cur_pos]
#token = tokens[:, :cur_pos]
logits, past_kv = model.forward(token, use_cache=_use_cache, past_kv=past_kv)
#logits = model.forward(token, use_cache=_use_cache, past_kv=past_kv)
logits = logits[:, -1, :]
            coeff = (
                (cur_pos - start_pos) / (total_len - start_pos) * cfg_gamma
                if cfg_gamma
                else 0.0
            )
if cfg_gamma and max_prompt_size < cur_pos:
alt_tok = alt_tokens[:, :start_pos] if cur_pos == start_pos else alt_tokens[:, cur_pos-1:cur_pos]
alt_logits, alt_kv = model.forward(alt_tok, use_cache=_use_cache, past_kv=alt_kv)
#uncond_logits = model.forward(tokens[:, :cur_pos], use_cache=_use_cache, past_kv=cfg_kv)
alt_logits = alt_logits[:, -1, :]
logits = logits + coeff * (alt_logits - logits)
if temperature > 0:
probs = torch.softmax(logits / temperature, dim=-1)
next_token = sample_top_p(probs, top_p)
else:
next_token = torch.argmax(logits, dim=-1)
next_token = next_token.reshape(-1)
# Only replace token if prompt has already been generated
next_token = torch.where(
input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
)
# Insert dim tokens
if force_end and cur_pos >= total_len - 130:
for _idx in range(bsz):
if (
dim_tok_inserted[_idx] is False
and tokenizer.id_to_tok[next_token[_idx].item()][0] != "dur"
):
next_token[_idx] = tokenizer.tok_to_id[tokenizer.dim_tok]
# Update dim_tok_inserted
for _idx in range(bsz):
if next_token[_idx] == tokenizer.tok_to_id[tokenizer.dim_tok]:
dim_tok_inserted[_idx] = True
tokens[:, cur_pos] = next_token
if cur_pos >= alt_len or (alpha is not None and coeff > alpha):
alt_tokens[:, cur_pos] = next_token
decoded = []
for idx, seq in enumerate(tokens.tolist()):
# Cut to max gen len
seq = seq[: len(prompts[idx]) + max_gen_len]
# Cut to eos tok if any
try:
seq = seq[: seq.index(eos_id)]
except ValueError:
pass
decoded.append(tokenizer.decode(seq))
return decoded
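# Example call (a sketch: `model`, `tokenizer`, `seq_a` and `seq_b` are
# placeholders for an already-loaded TransformerLM, its Tokenizer, and two
# unencoded token sequences; they are not defined in this file):
#   samples = interpolate_sample(
#       model, tokenizer,
#       prompts=[seq_a], alternative=[seq_b],
#       max_seq_len=2048, max_gen_len=512,
#       temperature=0.85, top_p=0.9, cfg_gamma=1.4,
#   )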
def sample_top_p(probs, p):
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
probs_sum = torch.cumsum(probs_sort, dim=-1)
mask = probs_sum - probs_sort > p
probs_sort[mask] = 0.0
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
next_token = torch.multinomial(probs_sort, num_samples=1)
next_token = torch.gather(probs_idx, -1, next_token)
return next_token
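# Quick sanity check for sample_top_p (safe to run on CPU): with p=0.5 over the
# toy distribution below, the cumulative mass reaches 0.5 within the two most
# likely tokens, so only ids 0 and 1 can ever be drawn.
if __name__ == "__main__":
    _probs = torch.tensor([[0.4, 0.3, 0.2, 0.1]])
    _draws = {sample_top_p(_probs, p=0.5).item() for _ in range(100)}
    assert _draws <= {0, 1}, _draws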
"""Includes (PyTorch) transformer model and config classes."""
import torch
import torch.utils.checkpoint
from torch import nn as nn
from torch.nn import functional as F
class ModelConfig:
def __init__(
self,
d_model: int,
n_heads: int,
n_layers: int,
ff_mult: int,
drop_p: float,
max_seq_len: int,
grad_checkpoint: bool,
):
self.d_model = d_model
self.n_heads = n_heads
self.n_layers = n_layers
self.ff_mult = ff_mult
self.drop_p = drop_p
self.max_seq_len = max_seq_len
self.grad_checkpoint = grad_checkpoint
def set_vocab_size(self, vocab_size: int):
self.vocab_size = vocab_size
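# Example instantiation (the values below are illustrative only and do not
# correspond to any released Aria checkpoint):
#   config = ModelConfig(
#       d_model=512, n_heads=8, n_layers=12, ff_mult=4,
#       drop_p=0.1, max_seq_len=2048, grad_checkpoint=False,
#   )
#   config.set_vocab_size(vocab_size=20000)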
# Taken from GPT-NeoX see:
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py
class RotaryEmbedding(torch.nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
if device is None: # todo: maybe we don't need this...
device = "cuda" if torch.cuda.is_available() else None
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
self.register_buffer("inv_freq", inv_freq)
# Build here to make `torch.jit.trace` work.
self._set_cos_sin_cache(
seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
)
def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
#self.cos_cached = emb.cos().to(dtype)
#self.sin_cached = emb.sin().to(dtype)
def forward(self, x, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
if seq_len > self.max_seq_len_cached:
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
return (
self.cos_cached[:seq_len].to(dtype=x.dtype),
self.sin_cached[:seq_len].to(dtype=x.dtype),
)
def rotate_half(x):
x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
return torch.cat(
(-x2, x1), dim=x1.ndim - 1
) # dim=-1 triggers a bug in earlier torch versions
@torch.jit.script
def apply_rotary_pos_emb(q, k, cos, sin, past_len: int = 0):
"""Returns tuple (xq, xk). Expects shape (s_len, b_sz, n_head, d_head)."""
cos = cos[past_len:past_len + q.size(0), None, None]
sin = sin[past_len:past_len + q.size(0), None, None]
return (q * cos) + (rotate_half(q) * sin), (k * cos) + (
rotate_half(k) * sin
)
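# Shape sketch for the rotary helpers above (illustrative only; the tensors
# here are stand-ins, not part of the model):
#   rotary = RotaryEmbedding(dim=32)
#   cos, sin = rotary(x=torch.zeros(2, 4, 16, 32), seq_len=16)  # each (16, 32)
#   q = k = torch.randn(16, 2, 4, 32, device=rotary.inv_freq.device)
#   q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)
#   # shapes preserved: (s_len, b_sz, n_head, d_head)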
class FusedEncoderBlock(nn.Module):
"""Transformer block using F.scaled_dot_product_attention().
This block has the following changes from a typical transformer encoder:
- Rotary embeddings are applied to the key/query matrices.
- Layer norm is applied before attention and feed forward, instead of
after.
- Keys arising from padding are masked during attention.
- GELU activation is used instead of ReLU.
Args:
model_config (ModelConfig): Model config settings.
"""
def __init__(self, model_config: ModelConfig):
super().__init__()
self.drop_p = model_config.drop_p
self.n_heads = model_config.n_heads
self.d_head = model_config.d_model // model_config.n_heads
self.max_seq_len = model_config.max_seq_len
# Positional embeddings
self.rotary_emb = RotaryEmbedding(self.d_head)
# Attention
self.mixed_qkv = nn.Linear(
in_features=model_config.d_model,
out_features=3 * model_config.d_model,
bias=False,
)
self.att_proj_linear = nn.Linear(
in_features=model_config.d_model,
out_features=model_config.d_model,
)
self.resid_dropout = nn.Dropout(model_config.drop_p)
# FF Layer
self.ff_dropout = nn.Dropout(model_config.drop_p)
self.ff_linear_1 = nn.Linear(
in_features=model_config.d_model,
out_features=model_config.d_model * model_config.ff_mult,
)
self.ff_linear_2 = nn.Linear(
in_features=model_config.d_model * model_config.ff_mult,
out_features=model_config.d_model,
)
self.ff_activation = nn.GELU()
# Pre layer norms
self.norm1 = nn.LayerNorm(model_config.d_model)
self.norm2 = nn.LayerNorm(model_config.d_model)
def forward(self, x: torch.Tensor, use_cache=False, past_kv=None):
att, kv = self._att_block(self.norm1(x), use_cache=use_cache, past_kv=past_kv)
x = x + att
x = x + self._ff_block(self.norm2(x))
return x, kv
def _att_block(self, x: torch.Tensor, use_cache=False, past_kv=None):
batch_size, seq_len, _ = x.shape
mixed_qkv = self.mixed_qkv(x)
xq, xk, xv = mixed_qkv.chunk(3, -1)
# Reshape for rotary embeddings
xq = xq.view(batch_size, seq_len, self.n_heads, self.d_head)
xk = xk.view(batch_size, seq_len, self.n_heads, self.d_head)
xv = xv.view(batch_size, seq_len, self.n_heads, self.d_head)
past_len = 0 if past_kv is None else past_kv[0].size(1)
# apply_rotary_post_emb expects: (s_len, b_sz, n_head, d_head)
cos, sin = self.rotary_emb(x=xv, seq_len=seq_len + past_len)
xq, xk = xq.transpose(0, 1), xk.transpose(0, 1)
xq, xk = apply_rotary_pos_emb(q=xq, k=xk, cos=cos, sin=sin, past_len=past_len)
xq, xk = xq.transpose(0, 1), xk.transpose(0, 1)
# xq, xk: (b_sz, s_len, n_head, d_head)
if past_kv is not None:
assert len(past_kv) == 2
xk = torch.concat([past_kv[0], xk], axis=1)
xv = torch.concat([past_kv[1], xv], axis=1)
kv = (xk, xv)
# Reshape for attention calculation: (b_sz, n_head, s_len, d_head)
xq = xq.transpose(1, 2)
xk = xk.transpose(1, 2)
xv = xv.transpose(1, 2)
# Required as we are not using a nn.Dropout layer
        if self.training:
            att_dropout = self.drop_p
        else:
            att_dropout = 0.0
# Using beta torch functionality (subject to change)
# See - https://shorturl.at/jtI17
if past_kv is None:
att = F.scaled_dot_product_attention(
query=xq,
key=xk,
value=xv,
dropout_p=att_dropout,
is_causal=True,
)
else:
assert xq.size(2) == 1
mask = torch.ones(1, xk.size(2), dtype=bool, device=xk.device)
att = F.scaled_dot_product_attention(
query=xq,
key=xk,
value=xv,
dropout_p=att_dropout,
is_causal=False,
attn_mask=mask,
)
# Reshape for out: (b_sz, s_len, n_head, d_head)
out = att.transpose(1, 2).contiguous()
out = out.view(batch_size, seq_len, self.n_heads * self.d_head)
return self.resid_dropout(self.att_proj_linear(out)), kv if use_cache else None
def _ff_block(self, x: torch.Tensor):
x = self.ff_linear_2(self.ff_activation(self.ff_linear_1(x)))
return self.ff_dropout(x)
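# Note on the cache layout used above: when use_cache is True, each block
# returns kv = (k, v) with both tensors shaped
# (batch_size, seq_len_so_far, n_heads, d_head); on the next step the new
# key/value projections are concatenated along the sequence dimension.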
class Transformer(nn.Module):
"""Transformer decoder with no language model head.
Args:
model_config (ModelConfig): Model config settings.
"""
def __init__(self, model_config: ModelConfig):
super().__init__()
self.model_config = model_config
self.tok_embeddings = nn.Embedding(
num_embeddings=model_config.vocab_size,
embedding_dim=model_config.d_model,
)
self.out_layer_norm = nn.LayerNorm(model_config.d_model)
self.encode_layers = nn.ModuleList()
for _ in range(model_config.n_layers):
self.encode_layers.append(FusedEncoderBlock(model_config))
def forward(self, src: torch.Tensor, use_cache=False, past_kv=None):
"""Forward pass of Transformer.
Args:
src (torch.tensor): Input to encoder block, of shape (batch_size,
seq_len, d_model).
Returns:
torch.tensor: Model outputs with shape (batch_size, seq_len,
d_model).
"""
hidden_states = self.tok_embeddings(src)
assert src.shape[1] <= self.model_config.max_seq_len, "Too long."
# NOTE: If you want to use gradient checkpointing then you must
# remove torch.compile from the train script as this is not currently
# supported.
# Implements gradient checkpoints on Encoder Layers.
        if self.model_config.grad_checkpoint is True:
            # KV caching is not implemented for the checkpointed path.
            assert use_cache is False, "use_cache is unsupported with grad_checkpoint"
for layer in self.encode_layers:
def create_custom_forward(module):
def custom_forward(*args):
return module(*args)
return custom_forward
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer),
hidden_states,
preserve_rng_state=True,
use_reentrant=True,
)
else:
new_past_kv = []
past_kv = [None] * len(self.encode_layers) if past_kv is None else past_kv
for layer, _kv in zip(self.encode_layers, past_kv):
hidden_states, kv = layer(hidden_states, use_cache=use_cache, past_kv=_kv)
new_past_kv.append(kv)
return self.out_layer_norm(hidden_states), new_past_kv if use_cache else None
class TransformerLM(nn.Module):
"""Transformer decoder with head for language modelling.
Args:
model_config (ModelConfig): Model config settings.
"""
def __init__(self, model_config: ModelConfig):
super().__init__()
self.max_seq_len = model_config.max_seq_len
self.model = Transformer(model_config)
self.lm_head = nn.Linear(
model_config.d_model, model_config.vocab_size, bias=False
)
def forward(self, src: torch.Tensor, use_cache=False, past_kv=None):
"""Forward pass of Transformer decoder with LM head.
Args:
src (torch.tensor): Input to encoder block, of shape (batch_size,
seq_len, d_model).
Returns:
torch.tensor: Forward pass of src through Transformer and LM head.
Has shape (batch_size, seq_len, vocab_size).
"""
hidden, past_kv = self.model(src, use_cache=use_cache, past_kv=past_kv)
logits = self.lm_head(hidden)
if use_cache:
return logits, past_kv
else:
return logits
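# A minimal smoke test for the classes above (a sketch: the hyperparameters
# are illustrative and do not correspond to any trained Aria checkpoint).
if __name__ == "__main__":
    _cfg = ModelConfig(
        d_model=128,
        n_heads=4,
        n_layers=2,
        ff_mult=4,
        drop_p=0.1,
        max_seq_len=256,
        grad_checkpoint=False,
    )
    _cfg.set_vocab_size(512)
    _device = "cuda" if torch.cuda.is_available() else "cpu"
    _lm = TransformerLM(_cfg).to(_device).eval()
    _src = torch.randint(0, 512, (2, 16), device=_device)
    with torch.inference_mode():
        _logits, _past_kv = _lm(_src, use_cache=True)  # prefill the KV cache
        _next = _logits[:, -1:].argmax(dim=-1)  # greedy pick of the next token
        _logits, _past_kv = _lm(_next, use_cache=True, past_kv=_past_kv)  # one cached decode step
    print(_logits.shape)  # torch.Size([2, 1, 512])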
"""Contains generation/sampling code"""
# This file contains code from https://github.com/facebookresearch/llama which
# is available under the following licence:
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the GNU
# General Public License version 3.
import torch
from typing import List
from aria.model import TransformerLM
from aria.tokenizer import Tokenizer
# TODO:
# - Enable sampling sequences longer than max_seq_len by truncating
# Some good settings:
# temp=0.85, top_p=0.9, cfg_gamma=1.4
@torch.autocast(device_type="cuda", dtype=torch.float16)
def greedy_sample(
model: TransformerLM,
tokenizer: Tokenizer,
prompts: List[list],
max_seq_len: int,
max_gen_len: int,
force_end=False,
temperature: float = 0.85,
top_p: float = 0.9,
cfg_gamma: float | None = 1.2,
):
"""Performs greedy (top_p) autoregressive sampling on a batch of prompts.
Args:
model (TransformerLM): Model to sample from.
tokenizer (Tokenizer): Tokenizer corresponding to model.
prompts (List[list]): A list of prompts to sample as a batch.
max_seq_len (int): Maximum sequence length supported by the model.
max_gen_len (int): Maximum desired sequence length of the samples.
temperature (float, optional): Sampling temperature. Defaults to 0.75.
top_p (float, optional): Parameter for top-p sampling. Defaults to 0.95.
Returns:
List[list]: The list of samples, decoded by the tokenizer.
"""
assert tokenizer.return_tensors is True, "tokenizer must return tensors."
model.eval()
pad_id = tokenizer.pad_id
eos_id = tokenizer.tok_to_id[tokenizer.eos_tok]
bsz = len(prompts)
min_prompt_size = min([len(t) for t in prompts])
max_prompt_size = max([len(t) for t in prompts])
total_len = min(max_seq_len, max_gen_len + max_prompt_size)
if cfg_gamma:
assert (
min_prompt_size == max_prompt_size
), "CFG not supported with varying prompts"
if force_end:
assert (
total_len - max_prompt_size > 130
), "prompt too long to use force_end=True"
print(
f"Using hyperparams: temp={temperature}, top_p={top_p}, gamma={cfg_gamma}, gen_len={max_gen_len}"
)
tokens = torch.full((bsz, total_len), pad_id).cuda()
for idx, unencoded_seq in enumerate(prompts):
tokens[idx, : len(unencoded_seq)] = tokenizer.encode(unencoded_seq)
dim_tok_inserted = [False for _ in range(bsz)]
input_text_mask = tokens != pad_id
start_pos = min_prompt_size
past_kv = None
cfg_kv = None
_use_cache = True
with torch.inference_mode():
for cur_pos in range(start_pos, total_len):
token = tokens[:, :start_pos] if cur_pos == start_pos else tokens[:, cur_pos-1:cur_pos]
#token = tokens[:, :cur_pos]
logits, past_kv = model.forward(token, use_cache=_use_cache, past_kv=past_kv)
#logits = model.forward(token, use_cache=_use_cache, past_kv=past_kv)
logits = logits[:, -1, :]
if cfg_gamma and max_prompt_size < cur_pos:
uncond_logits, cfg_kv = model.forward(tokens[:, cur_pos-1:cur_pos], use_cache=_use_cache, past_kv=cfg_kv)
#uncond_logits = model.forward(tokens[:, :cur_pos], use_cache=_use_cache, past_kv=cfg_kv)
uncond_logits = uncond_logits[:, -1, :]
logits = uncond_logits + cfg_gamma * (logits - uncond_logits)
if temperature > 0:
probs = torch.softmax(logits / temperature, dim=-1)
next_token = sample_top_p(probs, top_p)
else:
next_token = torch.argmax(logits, dim=-1)
next_token = next_token.reshape(-1)
# Only replace token if prompt has already been generated
next_token = torch.where(
input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
)
# Insert dim tokens
if force_end and cur_pos >= total_len - 130:
for _idx in range(bsz):
if (
dim_tok_inserted[_idx] is False
and tokenizer.id_to_tok[next_token[_idx].item()][0] != "dur"
):
next_token[_idx] = tokenizer.tok_to_id[tokenizer.dim_tok]
# Update dim_tok_inserted
for _idx in range(bsz):
if next_token[_idx] == tokenizer.tok_to_id[tokenizer.dim_tok]:
dim_tok_inserted[_idx] = True
tokens[:, cur_pos] = next_token
decoded = []
for idx, seq in enumerate(tokens.tolist()):
# Cut to max gen len
seq = seq[: len(prompts[idx]) + max_gen_len]
# Cut to eos tok if any
try:
seq = seq[: seq.index(eos_id)]
except ValueError:
pass
decoded.append(tokenizer.decode(seq))
return decoded
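# Example call (a sketch: `model`, `tokenizer` and `prompt_seq` are
# placeholders for an already-loaded TransformerLM, its Tokenizer, and an
# unencoded prompt sequence; they are not defined in this file):
#   samples = greedy_sample(
#       model, tokenizer, prompts=[prompt_seq],
#       max_seq_len=2048, max_gen_len=512,
#       temperature=0.85, top_p=0.9, cfg_gamma=1.4,
#   )
# The guidance line `uncond_logits + cfg_gamma * (logits - uncond_logits)`
# pushes the prompt-conditioned distribution away from the prompt-free one;
# with cfg_gamma=1.0 it reduces to the conditional logits.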
def sample_top_p(probs, p):
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
probs_sum = torch.cumsum(probs_sort, dim=-1)
mask = probs_sum - probs_sort > p
probs_sort[mask] = 0.0
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
next_token = torch.multinomial(probs_sort, num_samples=1)
next_token = torch.gather(probs_idx, -1, next_token)
return next_token