# Source: GitHub Gist by @crowsonkb
# (gist ID crowsonkb/6856f8bdd0cf713e2a6315cdaa8d2c53), last active June 15, 2023 21:52.
"""Stochastic beam search.
Implements "Stochastic Beams and Where to Find Them: The Gumbel-Top-k Trick for
Sampling Sequences Without Replacement" (https://arxiv.org/abs/1903.06059)"""
import math
import torch
def log1mexp(a):
    """Return ``log(1 - exp(a))`` elementwise, computed in a numerically stable way.

    Two algebraically equivalent formulas are evaluated and the more accurate
    one is picked per element:

    * ``log(-expm1(a))`` when ``a > -log(2)`` (``a`` close to zero, where
      ``1 - exp(a)`` would suffer catastrophic cancellation);
    * ``log1p(-exp(a))`` otherwise (``exp(a)`` is small, so ``log1p`` is
      accurate).

    Args:
        a: A tensor. The result is only real-valued for ``a <= 0``; ``a == 0``
            yields ``-inf``.

    Returns:
        A tensor with the same shape as ``a``.
    """
    close_to_zero = a > -math.log(2)
    # Both branches are evaluated for every element; torch.where then selects.
    via_expm1 = (-a.expm1()).log()
    via_log1p = (-a.exp()).log1p()
    return torch.where(close_to_zero, via_expm1, via_log1p)
def shift_gumbel(g, t):
    """Shift perturbed scores so that their maximum over the last axis equals ``t``.

    With ``Z = max(g)`` taken over the last axis and ``T = t[..., None]``, the
    function evaluates

        v      = T - g + log(1 - exp(g - Z))
        result = T - max(0, v) - log(1 + exp(-|v|))

    The element that attains the maximum has ``g - Z == 0``, so ``v == -inf``
    and the result there is exactly ``T``; every other element ends up ``<= T``.

    Args:
        g: Tensor of perturbed scores; the maximum is taken over its last axis.
        t: Tensor of target maxima with one fewer trailing axis than ``g``
            (it is broadcast against ``g`` after a trailing axis is appended).

    Returns:
        A tensor with the shape of ``g``.
    """
    target = t.unsqueeze(-1)
    row_max = g.max(dim=-1, keepdim=True).values
    v = target - g + log1mexp(g - row_max)
    # relu(v) is max(0, v) and softplus(-|v|) is log1p(exp(-|v|)); splitting the
    # expression this way keeps the exponential's argument non-positive.
    positive_part = torch.relu(v)
    soft_part = torch.nn.functional.softplus(torch.neg(torch.abs(v)))
    return target - positive_part - soft_part
def _reorder_past_key_values(past_key_values, beam_idx):
    """Gather rows ``beam_idx`` along the batch axis of a key/value cache.

    Handles ``None``, bare tensors, nested tuples/lists of tensors, and cache
    objects that expose a ``reorder_cache(beam_idx)`` method.

    NOTE(review): tensors are assumed to carry the batch on axis 0 and
    ``reorder_cache`` is assumed to reorder in place, as the Hugging Face cache
    classes do -- confirm for the model in use.
    """
    if past_key_values is None:
        return None
    if torch.is_tensor(past_key_values):
        return past_key_values[beam_idx.to(past_key_values.device)]
    if hasattr(past_key_values, "reorder_cache"):
        past_key_values.reorder_cache(beam_idx)
        return past_key_values
    if isinstance(past_key_values, (tuple, list)):
        return tuple(
            _reorder_past_key_values(item, beam_idx) for item in past_key_values
        )
    raise TypeError(
        f"cannot reorder past_key_values of type {type(past_key_values).__name__}"
    )


def stochastic_beam_search(model, input_ids, n_tokens, beam_width, temperature=1.0):
    """Sample up to ``beam_width`` distinct continuations with stochastic beam search.

    At every step each beam is expanded by every vocabulary entry, the
    cumulative log-probabilities are perturbed with Gumbel noise, the perturbed
    scores of each beam's children are shifted so their maximum equals the
    parent's perturbed score, and the top ``beam_width`` children are kept.

    Args:
        model: Callable invoked as
            ``model(ids, use_cache=True, past_key_values=...)`` whose result has
            ``.logits`` (indexed as ``[batch, position, vocab]``) and
            ``.past_key_values`` attributes.
        input_ids: Prompt token ids; must have exactly one row (2-D, batch 1).
        n_tokens: Number of tokens to generate.
        beam_width: Maximum number of sequences to keep.
        temperature: Divisor applied to the logits before the softmax.

    Returns:
        A tensor of token ids with ``n_tokens`` more columns than the prompt.
        It has ``beam_width`` rows, or fewer if fewer candidates exist (only
        possible when the vocabulary is smaller than ``beam_width``). Rows are
        ordered by decreasing perturbed score. With ``n_tokens == 0`` the
        prompt repeated ``beam_width`` times is returned.
    """
    assert input_ids.shape[0] == 1
    device = input_ids.device
    past_key_values = None

    # Initialize beam: the prompt is replicated for the model call, but only
    # one logical beam exists until the first expansion.
    input_ids = input_ids.repeat(beam_width, 1)
    phi_s = torch.zeros([1], device=device)  # cumulative log-prob per beam
    g_phi_s = torch.zeros([1], device=device)  # perturbed score per beam
    cur_beam_width = 1

    for _ in range(n_tokens):
        # With a cache only the newest token needs to be fed to the model.
        input_ids_in = input_ids if past_key_values is None else input_ids[:, -1:]
        with torch.no_grad():
            model_output = model(
                input_ids_in,
                use_cache=True,
                past_key_values=past_key_values,
            )
        past_key_values = model_output.past_key_values

        # Only the first `cur_beam_width` rows are distinct beams (on the first
        # step all rows are copies of the prompt).
        logits = model_output.logits[:cur_beam_width, -1, :].float() / temperature
        vocab_size = logits.shape[1]
        logprobs = torch.nn.functional.log_softmax(logits, dim=1)

        # Children's cumulative log-probs, then Gumbel perturbation
        # (-log(-log(U)) is a standard Gumbel sample), then shift so that each
        # beam's best child matches the parent's perturbed score.
        phi_s_prime = phi_s[:, None] + logprobs
        g_phi_s_prime = phi_s_prime - torch.log(-torch.log(torch.rand_like(logits)))
        g_phi_s_prime = shift_gumbel(g_phi_s_prime, g_phi_s)

        # Flat candidate index -> (parent beam, token id).
        src = torch.arange(cur_beam_width, device=device).repeat_interleave(vocab_size)
        y_prime = torch.arange(vocab_size, device=device).repeat(cur_beam_width)

        # Never request more candidates than exist (vocab < beam_width).
        k = min(beam_width, g_phi_s_prime.numel())
        g_phi_s, indices = torch.topk(g_phi_s_prime.flatten(), k=k)
        phi_s = phi_s_prime.flatten()[indices]

        parent = src[indices]
        input_ids = torch.cat([input_ids[parent], y_prime[indices, None]], dim=1)
        # The cache rows must follow the same parent selection as input_ids;
        # otherwise row i's cached keys/values describe a different prefix.
        past_key_values = _reorder_past_key_values(past_key_values, parent)
        cur_beam_width = g_phi_s.shape[0]

    return input_ids
# End of gist source (GitHub comment-form footer text omitted).