Skip to content

Instantly share code, notes, and snippets.

@ed1d1a8d
Last active March 14, 2024 08:41
Show Gist options
  • Save ed1d1a8d/eb98e6ea47646589b2a8423bc9ac992e to your computer and use it in GitHub Desktop.
import numpy as np
import openai
import scipy.special
import tiktoken
def _token_idx_from_top_logprob(tokenizer, top_logprob) -> int:
    """Map one API ``top_logprob`` entry to its tiktoken token index.

    Special tokens (e.g. ``<|endoftext|>``) come back from the API with
    ``bytes`` set to None, so they cannot be resolved via their byte string
    and are mapped to the tokenizer's EOT token index instead.
    """
    if top_logprob.token in ("<|end|>", "<|endoftext|>"):
        return tokenizer.eot_token
    return tokenizer.encode_single_token(bytes(top_logprob.bytes))


def get_top_chat_logprobs(
    model: str,
    messages: list[dict[str, str]],
    seed: int = 42,
    n_logprobs: int = 20,
) -> dict[int, tuple[float, str, str]]:
    """Recover the top ``n_logprobs`` next-token log-probabilities for a chat
    completion, beyond the API's per-request top-logprobs cap.

    Works by repeatedly re-querying with ``logit_bias=-100`` applied to every
    token already seen, then correcting the returned (renormalized) logprobs
    back to the unbiased distribution.

    Returns a dict mapping token_idx to (logprob, token_str, system_fingerprint)
    for the top n_logprobs tokens.

    Supports a maximum of 305 logprobs, but is flaky for 290+ logprobs
    (i.e. 290-305 logprobs will sometimes work, sometimes error). 306 logprobs
    and above will always error (unless OpenAI API changes).
    """
    tokenizer = tiktoken.encoding_for_model(model)
    assert 1 <= n_logprobs <= tokenizer.n_vocab
    client = openai.Client()

    def query_api(**kwargs):
        # Greedy single-token completion. top_logprobs=5 tokens per request
        # plus the logit_bias limit of 300 suppressed tokens explains the
        # documented 305-logprob ceiling.
        return client.chat.completions.create(
            model=model,
            messages=messages,
            temperature=0,
            max_tokens=1,
            n=1,
            logprobs=True,
            top_logprobs=5,
            seed=seed,
            **kwargs,
        )

    base_resp = query_api()
    # Maps token_idx -> (logprob, token_str, system_fingerprint).
    # NOTE: unlike the original, special tokens in the base response are
    # handled here too (their .bytes is None, which would crash bytes(...)).
    logprob_dict: dict[int, tuple[float, str, str]] = {
        _token_idx_from_top_logprob(tokenizer, tl): (
            tl.logprob,
            tl.token,
            base_resp.system_fingerprint,
        )
        for tl in base_resp.choices[0].logprobs.content[0].top_logprobs
    }

    BIAS = -100  # Minimum allowed logit_bias; effectively masks a token.
    while len(logprob_dict) < n_logprobs:
        # True log-probability mass of all tokens found so far.
        log_masked_sum = scipy.special.logsumexp(
            [logprob for logprob, _, _ in logprob_dict.values()]
        )
        # log(1 - exp(log_masked_sum)): mass of all tokens not yet found.
        log_unmasked_sum = np.log(-scipy.special.expm1(log_masked_sum))

        # Suppress everything already seen so new tokens surface in the top 5.
        resp = query_api(
            logit_bias={token_idx: BIAS for token_idx in logprob_dict}
        )
        for top_logprob in resp.choices[0].logprobs.content[0].top_logprobs:
            if len(logprob_dict) >= n_logprobs:
                break
            token_idx = _token_idx_from_top_logprob(tokenizer, top_logprob)
            # Undo the renormalization induced by the logit bias: the biased
            # normalizer is exp(BIAS)*masked_mass + unmasked_mass, so
            # log p_true = log p_biased + log(normalizer).
            true_logprob = top_logprob.logprob + np.logaddexp(
                log_masked_sum + BIAS, log_unmasked_sum
            )
            logprob_dict[token_idx] = (
                true_logprob,
                top_logprob.token,
                resp.system_fingerprint,
            )
        print(len(logprob_dict))  # Progress indicator across API rounds.

    return logprob_dict
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment