@josiahbryan
Created April 20, 2024 01:16
Convert an array of logprobs to perplexity and jointProb values
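
The function below calls a normalizeDecimals helper that is not included in this gist. A minimal sketch of what it is assumed to do (round a number to a fixed number of decimal places) so the snippet runs on its own; the real helper may differ:

// Assumed helper, not part of the original gist: rounds `value` to `decimals` places
const normalizeDecimals = (value, { decimals = 3 } = {}) => {
	const factor = 10 ** decimals;
	return Math.round(value * factor) / factor;
};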
export const annotateLogProbs = (logprobs) => {
	if (!logprobs?.length) {
		// Handle callers that pass the whole logprobs object instead of destructuring .content for us
		if (logprobs?.content?.length) {
			// eslint-disable-next-line no-param-reassign
			logprobs = logprobs.content;
		} else {
			return {};
		}
	}
	// Not sure why, but some models return null logprobs even though they generate tokens. Filter them out.
	const nonNullLogProbs = logprobs
		.map((x) => x.logprob)
		.filter((x) => x !== null);
	// Sum of the log probabilities, used below in the joint probability calculation
	const sumProbs = nonNullLogProbs.reduce((acc, logprob) => acc + logprob, 0);
	/*
	**The preferred way** to calculate the joint probability when working with log-probabilities is to:
	1. Sum the log-probabilities.
	2. Convert the sum back to a probability using the exponential function.
	This effectively computes the product of the individual probabilities while staying in log space for numerical stability. Mathematically, given log probabilities \( \log(p_1), \log(p_2), \ldots, \log(p_n) \), the joint probability \( P \) is:
	\[
	P = \exp(\log(p_1) + \log(p_2) + \ldots + \log(p_n))
	\]
	This approach is especially useful in language models and other sequential data models, where you're interested in the probability of observing the entire sequence as it occurred.
	*/
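	// Worked example (illustrative arithmetic, not from the original gist):
	// for natural logprobs [-0.1, -0.2, -0.3], the sum is -0.6 and
	// jointProb = Math.exp(-0.6) ≈ 0.549, i.e. the probability of the whole
	// token sequence occurring exactly as generated.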
	const jointProb = Math.exp(sumProbs);
	// const perplexity = Math.exp(-1 * avg(probs));
	/*
	From ChatGPT:
	Perplexity is a measure used in language modeling to evaluate how well a probability model predicts a sample. It is defined as the exponentiation of the entropy of the distribution. When working with the log probabilities of a sequence, perplexity can be computed as:
	\[
	\text{Perplexity} = 2^{-\frac{1}{N} \sum_{i=1}^{N} \log_2(p_i)}
	\]
	where \( N \) is the number of tokens and \( \log_2(p_i) \) is the base-2 log probability of each token. The logprobs in the input array are natural logs, so they are converted to base 2 by dividing by \( \ln(2) \).
	### Explanation of the code below (sumOfLog2Probs and perplexity calculation):
	- **`logprob / Math.log(2)`**: change of base from natural log to log2, since perplexity is traditionally defined in base 2.
	- **`reduce()`**: sums the converted log2 probabilities.
	- **`2 ** (...)`**: raises 2 to the negative average log2 probability per token, which is the perplexity.
	*/
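	// Worked example (illustrative arithmetic, not from the original gist):
	// natural logprobs [-0.1, -0.2, -0.3] convert to base-2 values of roughly
	// [-0.144, -0.289, -0.433]; their mean is ≈ -0.289, so
	// perplexity = 2 ** 0.289 ≈ 1.22 (equivalently exp(0.2), the negative
	// average natural logprob exponentiated).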
	// Calculate the sum of log2 probabilities (natural log converted to base 2)
	const sumOfLog2Probs = nonNullLogProbs.reduce(
		(acc, logprob) => acc + logprob / Math.log(2),
		0,
	);
	// Calculate perplexity: 2 to the negative average log2 probability per token
	const perplexity = 2 ** (-sumOfLog2Probs / nonNullLogProbs.length);
	const annotations = {
		jointProb,
		perplexity,
		// Convert the ORIGINAL array to probabilities so that the indices match
		// the original input (nulls and all) for easier debugging, unlike
		// nonNullLogProbs above, which is only used for the calculations and does
		// not preserve the original indices
		probabilities: logprobs.map((x) =>
			x.logprob === null
				? null
				: normalizeDecimals(Math.exp(x.logprob), {
						decimals: 3,
				  }),
		),
	};
	return annotations;
};
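
A usage sketch (not part of the original gist), assuming an OpenAI-style chat completion where choices[0].logprobs.content is an array of { token, logprob } entries; the function also accepts that object directly thanks to the .content fallback above:

// Hypothetical example input shaped like the OpenAI logprobs.content array
const sampleLogprobs = [
	{ token: 'Hello', logprob: -0.05 },
	{ token: ',', logprob: -0.6 },
	{ token: ' world', logprob: -0.2 },
];

const { jointProb, perplexity, probabilities } = annotateLogProbs(sampleLogprobs);
console.log({ jointProb, perplexity, probabilities });
// => jointProb ≈ 0.427, perplexity ≈ 1.328, probabilities ≈ [0.951, 0.549, 0.819]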