Created April 20, 2024 01:16
Gist: josiahbryan/f515db1283dd330dbf568051bfbac021
Convert an array of logprobs to perplexity and jointProb values
export const annotateLogProbs = (logprobs) => {
  if (!logprobs?.length) {
    // Handle callers that pass the full logprobs object
    // instead of destructuring the `content` array for us
    if (logprobs?.content?.length) {
      // eslint-disable-next-line no-param-reassign
      logprobs = logprobs.content;
    } else {
      return {};
    }
  }
  // Some models return null logprobs even though they generate tokens;
  // filter those out before doing any math
  const nonNullLogProbs = logprobs
    .map((x) => x.logprob)
    .filter((x) => x !== null);
  // Sum of the logprobs, used below in the joint probability calculation
  const sumProbs = nonNullLogProbs.reduce((acc, logprob) => acc + logprob, 0);
  /*
  The preferred way to compute a joint probability from log-probabilities is to:
  1. Sum the log-probabilities.
  2. Convert the sum back to a probability with the exponential function.
  This computes the product of the individual probabilities while staying in
  log space for numerical stability. Given log probabilities
  \( \log(p_1), \log(p_2), \ldots, \log(p_n) \), the joint probability \( P \) is:
  \[
  P = \exp(\log(p_1) + \log(p_2) + \ldots + \log(p_n))
  \]
  This is especially useful in language models and other sequential models,
  where we want the probability of observing the entire sequence as it occurred.
  */
  const jointProb = Math.exp(sumProbs);
  /*
  (Explanation adapted from ChatGPT.)
  Perplexity measures how well a probability model predicts a sample; it is the
  exponentiation of the entropy of the distribution. Given per-token log
  probabilities, it can be computed as:
  \[
  \text{Perplexity} = 2^{-\frac{1}{N} \sum_{i=1}^N \log_2(p_i)}
  \]
  where \( N \) is the number of tokens and \( \log_2(p_i) \) is the base-2 log
  probability of each token. The logprobs we receive are natural logs, so each
  one must be converted to base 2 by dividing by \( \ln(2) \).

  In the code below:
  - `logprob / Math.log(2)` converts each natural-log value to base 2.
  - `reduce()` sums the converted log probabilities.
  - `2 ** x` raises 2 to the negative average log2 probability per token,
    which is the perplexity.
  */
  // Sum the base-2 log probabilities
  const sumOfLog2Probs = nonNullLogProbs.reduce(
    (acc, logprob) => acc + logprob / Math.log(2),
    0,
  );
  // Perplexity: 2 to the negative average log2 probability per token
  const perplexity = 2 ** (-sumOfLog2Probs / nonNullLogProbs.length);
  const annotations = {
    jointProb,
    perplexity,
    // Convert the ORIGINAL array to probabilities so that the indices match
    // the original input (nulls and all) for easier debugging, unlike
    // nonNullLogProbs above, whose indices don't line up with the input.
    // (normalizeDecimals is an external helper, not defined in this gist,
    // which presumably rounds to the given number of decimal places.)
    probabilities: logprobs.map((x) =>
      x.logprob === null
        ? null
        : normalizeDecimals(Math.exp(x.logprob), {
            decimals: 3,
          }),
    ),
  };
  return annotations;
};
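For reference, here is a quick sanity check of the two formulas on hypothetical token logprobs. The values and variable names below are made up for illustration; they are not part of the gist:

```javascript
// Hypothetical natural-log probabilities for three tokens
const logprobs = [-0.1, -0.5, -0.25];

// Joint probability: exp of the sum of logprobs
// (equivalent to multiplying the individual probabilities)
const sum = logprobs.reduce((a, b) => a + b, 0); // -0.85
const jointProb = Math.exp(sum);

// Perplexity, base-2 form: convert each ln value to log2 by
// dividing by Math.log(2), average, negate, and raise 2 to it
const sumLog2 = logprobs.reduce((a, b) => a + b / Math.log(2), 0);
const perplexity = 2 ** (-sumLog2 / logprobs.length);

// Equivalent base-e form: exp of the negative mean ln probability.
// The bases cancel, so both forms give the same number.
const perplexityE = Math.exp(-sum / logprobs.length);

console.log({ jointProb, perplexity, perplexityE });
```

The base-e form is handy as a cross-check: if the two perplexity values ever disagree, the base conversion is being applied inconsistently.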