/* eslint-disable no-unused-vars */
import Logger, { chalk } from 'shared/utils/Logger';
import { jsonSafeParse } from 'shared/utils/jsonSafeParse';
import { jsonSafeStringify } from 'shared/utils/jsonSafeStringify';
import { normalizeDecimals } from 'shared/utils/normalizeDecimals';
import { dedupBy } from 'shared/utils/dedupBy';
import { setProperty } from 'dot-prop';
import { annotateLogProbs } from './annotateLogProbs';
const propPathArrayToString = (pathArray) => {
// Instead of JUST joining "." we also replace '.[' with '['
// so that instead of: "foo.[0].bar.[1].baz" we get "foo[0].bar[1].baz"
return pathArray.join('.').replace(/\.\[/g, '[');
};
/**
* @summary Helpers to parse JSON from token lists and annotate each key/value pair with log probabilities.
*
* Created a very blunt, hand-rolled JSON parser to brute-force parse an array of tokens from an LLM into JavaScript whenever it finds a '[' or a '{' in the token array.
*
* It only parses ABSOLUTELY VALID JSON and does nothing fancy beyond accumulating it. However, it is very successful at getting JSON out of LLMs without any sort of special tagging, code blocks, or anything else - even if the JSON is surrounded by paragraphs of text, we find it.
*
* This serves two purposes:
* 1. It allows us to parse JSON from streaming responses without having to wait for the entire stream to land. This alone could theoretically let us abort the stream if an invalid JSON response was given or if one of the values was not acceptable, etc.
* 2. It allows us to get explicit logprobs/perplexity PER VALUE (and per key) in the JSON block - incredibly useful.
*
* The streaming JSON parsing ALSO supports annotating key/values with logprob/perplexity - it uses the exact same code as non-streaming parsing. The stream handler below really is just an accumulator + onStreamChunk handler that accumulates chunks and repeatedly brute-force parses them - so not TRUE streaming where we discard processed chunks - we could add that if needed, but for now, this serves our purposes.
*
* The logprob-per-key/value annotation also allows us to either circuit-break streams (abort the stream on high-perplexity values, etc.) or discard/regenerate the content if perplexity/probabilities on specific keys are out of range.
*
* All of this (streaming JSON, perplexity/logprob-per-key/value) is integrated at the root level in callChatModel.
*
* Streaming JSON works with both OpenAI and Claude, but logprobs-per-key/value only work with OpenAI right now because Claude doesn't give them to us. However, the brute-force method of parsing JSON is particularly useful with both Claude and OpenAI because it somewhat alleviates the need for predictTaggedJson - but that's to be determined in the future.
*
* The end result of all of this is that if there is JSON anywhere in the result from 'predict' (llmPredictFactory) and you inspect the metadata of the result, it will be automatically parsed (regardless of the model) and annotated with logprob/perplexity (with OpenAI) - and if you're using the streaming version, you can also get the same logprob/perplexity annotations on a per-key/value basis.
*
* (The stream wrapper below adds the .json prop to the root chunk with the FULL json parse from the start of the stream, so that's where you'd go to circuit-break mid-stream if needed - e.g. use an onStreamChunk(chunk) {} handler to check the `chunk.json.simplified` object for perplexity, etc).
*
* This is a very powerful feature that allows us to do some very interesting things with the LLMs.
*/
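/*
 * Quick illustration of the "find JSON anywhere in the text" behavior described
 * above (a hedged sketch - tokens are hand-built here rather than taken from a
 * real model response):
 *
 *   const text = 'Sure! Here is the data: {"score": 0.9} Hope that helps.';
 *   const logprobs = Array.from(text).map((token) => ({ token, logprob: 0 }));
 *   const { result, simplified } = logprobsToAnnotatedJson({ logprobs });
 *   // The surrounding prose is skipped: result.score.value === 0.9, and
 *   // simplified[0] is roughly ['score', { key: 'score', value: 0.9, ... }].
 */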
/**
* @function logprobsToAnnotatedJson
*
* Converts an array of logprobs to an annotated JSON object - implements
* VERY VERY basic, blunt-force parsing of JSON that very possibly may fail, YMMV.
*
* Returns { result, simplified } where result is rows of { path, key, value, keyMetadata, valueMetadata, at, final }
* and simplified is an array of "entries" like:
* ```js
* [
*   [
*     'sentiment.negative',
*     {
*       key: 'negative',
*       value: 0.1,
*       keyProb: 0.999996,
*       valueProb: 0.999794,
*       finished: true,
*     },
*   ],
*   [
*     'recommendation',
*     {
*       key: 'recommendation',
*       value: "I'd recommend this to children who enjoy stories about animals.",
*       keyProb: 0.999999,
*       valueProb: 0.941542,
*       finished: true,
*     },
*   ],
* ]
* ```
* Which would correspond to a JSON object like:
* ```json
* {
*   "sentiment": {
*     "negative": 0.1
*   },
*   "recommendation": "I'd recommend this to children who enjoy stories about animals."
* }
* ```
*
* Which one you use (`simplified` or `result`) depends on your use-case. Everything in 'simplified'
* is in 'result', just might take a bit more work to extract. I personally prefer 'simplified' for most uses.
*
* Note that 'finished' in the simplified output indicates that the value is complete, and is not a partial value,
* and it corresponds to the 'final' prop in the result output.
*
* This 'finished' / 'final' value is only needed for streaming, and is not needed for non-streaming use-cases,
* but it is included regardless. You can just ignore it if you know you're not getting streamed data,
* or if you want to only process values that are complete, you can filter on 'finished' (or 'final').
*
* @param {Object} options - The options for the conversion.
* @param {Array} options.logprobs - The array of logprobs to convert.
* @param {boolean} [options.debug=false] - A flag indicating whether to output debug information.
* @param {boolean} [options.dropUnparsed=false] - A flag indicating whether to drop unparsed values.
* @param {boolean} [options.returnSimplifiedOnly=false] - A flag indicating whether to return only the simplified array instead of { result, simplified, tokenNum }.
* @param {Object} [options.logger=Logger] - The logger object to use for logging.
* @param {Object} [options.recursionState={}] - Internal state threaded through recursive calls when parsing nested objects/arrays.
* @param {number} [options.recursionMaxDepth=10] - Maximum recursion depth before bailing out with an error.
* @returns {Object|Array} The simplified array when returnSimplifiedOnly is set, otherwise { result, simplified, tokenNum }.
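* @example
* // Minimal sketch (hand-built single-character tokens for illustration; real
* // calls pass the `logprobs.content` array from an OpenAI chat completion):
* const tokens = Array.from('{"ok": true}').map((token) => ({
*   token,
*   logprob: Math.log(0.99),
* }));
* const { result, simplified } = logprobsToAnnotatedJson({ logprobs: tokens });
* // simplified is roughly:
* // [['ok', { key: 'ok', value: true, keyProb: ..., valueProb: ..., finished: true }]]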
*/
export function logprobsToAnnotatedJson({
logprobs,
logger = Logger,
debug,
dropUnparsed = false, // if no closing value found, don't bluntly push to end - off by default
returnSimplifiedOnly = false,
recursionState = {},
recursionMaxDepth = 10,
} = {}) {
// assertRequired({ logprobs }, 'logprobsToAnnotatedJson', logger);
if (!logprobs?.length) {
if (debug) {
logger.warn(`[logprobsToAnnotatedJson] No logprobs to parse`);
}
return { error: 'No logprobs to parse', simplified: [], result: {} };
}
const results = recursionState?.results || [];
const jsonPropPath = [];
let accum = {};
const resetAccum = () => {
accum = { list: [], key: null, value: undefined };
};
resetAccum();
const endsWithTerminatorToken = (array) => {
const lastToken = array[array.length - 1]?.token;
if (!lastToken) {
return false;
}
return (
lastToken?.includes('\n') ||
lastToken?.includes(':') ||
lastToken?.includes('"') ||
lastToken?.includes('}') ||
lastToken?.includes(']') ||
lastToken?.includes(',')
);
};
let isInsideKey = false;
let isInsideValue = false;
let isInsideString = false;
// let debug = false;
const keyValSepToken = '":';
const stringQuote = '"';
const lineTerm = '\n';
const objectOpen = '{';
const objectTerm = '}';
const arrayOpen = '[';
const arrayTerm = ']';
const bluntTerminateValue = ({ at = 'valueEnd' } = {}) => {
// Grab a simple probabilities list of the tokens for storage and annotation
// const valueLogProbs = accum.list.map((x) => ({
// token: x.token,
// logprob: x.logprob,
// }));
const valueLogProbs = dedupBy(accum.list, (x) => x.originalIndex).map(
(x) => ({
token: x.originalToken,
logprob: x.logprob,
}),
);
if (valueLogProbs[0]?.token?.includes('"')) {
valueLogProbs.shift();
}
if (endsWithTerminatorToken(valueLogProbs)) {
valueLogProbs.pop();
}
// Calculate perplexity, jointProb, and other stats
const annotatedProbs = annotateLogProbs(valueLogProbs);
// End of a value - accumulate the value into a single variable
let accumData = accum.list
?.map((x) => x.token)
.join('')
?.trim();
// Remove \n so we can try to parse json
if (accumData.endsWith('\n')) {
accumData = accumData.slice(0, -1);
}
// Remove closing comma if it's there so we can try to parse
if (accumData.endsWith(',')) {
accumData = accumData.slice(0, -1);
}
// Try to parse the data or just use the tokenized data if cannot parse
accum.value = jsonSafeParse(accumData) || accumData;
if (accum.value) {
if (isInsideKey) {
// streaming, fake end a key with NO VALUE
if (debug) {
logger.warn(`Streaming key end`, {
key: accum.value,
value: undefined,
currentPath: propPathArrayToString(jsonPropPath),
});
}
// Add to our results
results.push({
path: propPathArrayToString(jsonPropPath),
key: accum.value,
keyMetadata: {
...normalizeDecimals(annotatedProbs, { decimals: 6 }),
logprobs: valueLogProbs,
},
value: null,
valueMetadata: undefined,
at,
final: false,
});
} else {
if (debug) {
logger.warn(`EoValue`, {
key: accum.key,
value: accum.value,
currentPath: propPathArrayToString(jsonPropPath),
});
}
// Add to our results
results.push({
path: propPathArrayToString(jsonPropPath),
key: accum.key,
keyMetadata: accum.keyMetadata,
value: accum.value,
valueMetadata: {
...normalizeDecimals(annotatedProbs, { decimals: 6 }),
logprobs: valueLogProbs,
},
at,
final: at === 'valueEnd',
});
}
}
};
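// Explode each model token into per-character "fake" tokens, tagging each one
// with its original token and index so the character-level parse can be deduped
// back to real tokens later (see the dedupBy(..., (x) => x.originalIndex) calls
// in bluntTerminateValue and the key-end handling).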
const reformatLogProbs = (array) => {
const fakeTokenSet = [];
array.forEach((rawToken, idx) => {
const { token, ...rest } = rawToken;
// Using 'Array.from' instead of 'split("")' to make sure everything is
// properly handled in character units (e.g. emojis, etc). This approach
// is highly effective for strings containing emojis, special Unicode symbols,
// and other complex characters that might span multiple UTF-16 code units.
const subTokens = Array.from(rawToken.token);
subTokens.forEach((char) => {
fakeTokenSet.push({
token: char,
...rest,
originalIndex: idx,
originalToken: token,
});
});
});
// if (debug) {logger.debug(`[reformatLogProbs]`, fakeTokenSet);}
return fakeTokenSet;
};
// Make a copy of the OUTER logprobs, but for recursion, we want to consume the same array
// so when the recursion is done, we can continue where we left off
const consumableProbs =
recursionState?.consumableProbs || reformatLogProbs(logprobs);
// Object.assign(recursionState, { consumableProbs });
const recursionIndent = chalk
.gray(' [.] ')
.repeat(recursionState.recursion || 0);
if (recursionState?.recursion > recursionMaxDepth) {
if (debug) {
logger.error(
`${recursionIndent} recursion depth exceeded, returning error`,
);
}
return {
error: `Recursion depth exceeded: ${recursionState.recursion}`,
simplified: [],
result: {},
};
}
// Find the first opening token of an object or an array before we start
let nextToken;
let tokenNum = recursionState?.tokenNum || -1; // for debugging
let objectType = recursionState?.type || null;
const ObjectTypes = {
Object: 'OBJECT',
Array: 'ARRAY',
};
const isObjTypeToken = (token) =>
token?.token?.includes(objectOpen) || token?.token?.includes(arrayOpen);
const getNextToken = () => consumableProbs[++tokenNum];
if (!objectType) {
nextToken = getNextToken();
if (!nextToken) {
if (debug) {
logger.error(
`${recursionIndent} no next token at start # ${tokenNum}, returning error`,
);
}
return {
error: `No next token at start #${tokenNum}`,
simplified: [],
result: {},
};
}
while (nextToken && !isObjTypeToken(nextToken)) {
if (debug) {
logger.info(
`${recursionIndent} searching for start, skipping token ${tokenNum}, '${nextToken?.token}'`,
);
}
nextToken = getNextToken();
}
if (!nextToken) {
if (debug) {
logger.error(
`${recursionIndent} never found start token, ended at # ${tokenNum}, returning error`,
);
}
return {
error: `Never found start token, ended at # ${tokenNum}`,
simplified: [],
result: {},
};
}
// Decide which parsing mode to enact (object, e.g. key/val, or array, e.g. just values)
objectType = nextToken?.token?.includes(arrayOpen)
? ObjectTypes.Array
: ObjectTypes.Object;
}
if (debug) {
logger.info(
`${recursionIndent} START of new object type: ${chalk.bgCyanBright(
chalk.black(` ${objectType} `),
)} @ token ${tokenNum}`,
nextToken,
);
}
nextToken = getNextToken();
if (debug) {
logger.info(
`${recursionIndent} starting PARSING with token ${tokenNum}:`,
nextToken,
);
}
// if (Date.now()) {
// process.exit(0);
// }
const consumeChild = ({ type }) => {
// consumableProbs.unshift(nextToken); // put the object or array token back on the stack so we can consume it
const childData = logprobsToAnnotatedJson({
logprobs: consumableProbs,
logger,
debug,
dropUnparsed,
recursionState: {
type,
consumableProbs,
tokenNum, // > 0 ? tokenNum - 1 : tokenNum, // move back one token so child can consume the object or array token
recursion: (recursionState.recursion || 0) + 1,
},
});
tokenNum = childData?.tokenNum || tokenNum;
const arrayKey = propPathArrayToString(jsonPropPath);
if (debug) {
logger.warn(
`${recursionIndent} child results from array/object at ${arrayKey}:`,
childData,
);
}
if (childData.error) {
if (debug) {
logger.error(
`${recursionIndent} child error from array/object at ${arrayKey}:`,
childData.error,
);
}
return false;
}
const keyedResults = Object.entries(childData?.result || {})?.map(
([path, rest]) => ({
path: propPathArrayToString([arrayKey, path]), // add array index to path
...rest,
}),
);
// No need to bluntTerminateValue here;
results.push(...(keyedResults || []));
return true;
};
const renderParserState = ({
// eslint-disable-next-line no-shadow
recursionIndent,
// eslint-disable-next-line no-shadow
tokenNum,
data,
tokenStr,
// eslint-disable-next-line no-shadow
isInsideKey,
// eslint-disable-next-line no-shadow
isInsideValue,
// eslint-disable-next-line no-shadow
isInsideString,
isStringQuote,
isTerminator,
isObjectOpen,
isObjectTerm,
isArrayOpen,
isArrayTerm,
}) => {
let base = `${recursionIndent} [${objectType}#${tokenNum}] [${chalk.bgWhite(
chalk.black(data),
)}] (${chalk.bgYellow(chalk.black(jsonSafeStringify(tokenStr)))}) {`;
const flags = Object.entries({
isInsideKey,
isInsideValue,
isInsideString,
isStringQuote,
isTerminator,
isObjectOpen,
isObjectTerm,
isArrayOpen,
isArrayTerm,
currentPath: propPathArrayToString(jsonPropPath),
// isStringOpenToken,
// isKeyValSepToken,
// isLineTerm,
// isObjectTerm,
// isObjectOpen,
// isStringQuote,
}).map(([key, value]) => {
let cv = chalk.yellow(value);
if (value === true) {
cv = chalk.green(value);
} else if (value === false || value === undefined) {
cv = chalk.gray(value);
}
return `${chalk.gray(key)}: ${cv}`;
});
return `${base} ${flags.join(', ')} }`;
};
let previousTokenStr; // for detecting escaped quotes
// Array, so just consume a stream of values (simple or object or other arrays)
if (objectType === ObjectTypes.Array) {
let arrayIndex = 0;
jsonPropPath.push(`[${arrayIndex}]`);
while (nextToken) {
const { token: tokenStr } = nextToken;
const data = tokenStr?.trim();
const isLineTerm = tokenStr.includes(lineTerm);
const isObjectTerm = tokenStr.includes(objectTerm);
const isObjectOpen = tokenStr.includes(objectOpen);
const isArrayOpen = tokenStr.includes(arrayOpen);
const isArrayTerm = tokenStr.includes(arrayTerm);
const isComma = tokenStr.includes(',');
const isStringQuote =
tokenStr.includes(stringQuote) && !previousTokenStr?.includes('\\');
const isTerminator =
(isInsideValue && isStringQuote) ||
(!isInsideString &&
(isLineTerm || isObjectTerm || isArrayTerm || isComma));
// Store previous for escaped quote detection
previousTokenStr = tokenStr;
if (debug) {
logger.debug(
renderParserState({
recursionIndent,
tokenStr,
tokenNum,
data,
isInsideKey,
isInsideValue,
isInsideString,
isStringQuote,
isTerminator,
isObjectOpen,
isObjectTerm,
isArrayOpen,
isArrayTerm,
}),
);
}
if (!isInsideValue) {
if (isObjectOpen || isArrayOpen) {
// Recurse in and consume the child
if (
consumeChild({
type: isArrayOpen ? ObjectTypes.Array : ObjectTypes.Object,
})
) {
jsonPropPath.pop();
resetAccum();
arrayIndex++;
jsonPropPath.push(`[${arrayIndex}]`);
} else {
if (debug) {
logger.error(
`${recursionIndent} child error, breaking ARRAY loop`,
);
}
break;
}
} else if (data && !isArrayTerm && !isTerminator) {
// not an array/object, consume as a simple value and change values based on commas
// or array-end tokens or line-end tokens
isInsideValue = true; // values can be quoted strings, so if we have a key, then the first quote we find is a value
if (debug) {
logger.debug(
`${recursionIndent} > starting simple array value at index (${arrayIndex}) ...`,
);
}
isInsideString = !!isStringQuote;
accum.list = [];
}
} // If we hit a terminator while inside a value, then we are at the end of a value
else if (isInsideValue && isTerminator) {
// get any ending tokens
accum.list.push({ ...nextToken, data, type: 'data' });
isInsideValue = false;
isInsideString = false;
accum.key = arrayIndex;
// End of a value
bluntTerminateValue();
jsonPropPath.pop();
resetAccum();
arrayIndex++;
jsonPropPath.push(`[${arrayIndex}]`);
}
if (!isInsideValue && isArrayTerm) {
// Stop consuming tokens because we are at the end of the array
break;
}
// Accumulate data for whatever we're doing
accum.list.push({ ...nextToken, data, type: 'data' });
nextToken = getNextToken();
if (!nextToken) {
if (debug) {
logger.error(`${recursionIndent} no next token, breaking ARRAY loop`);
}
break;
}
}
} else {
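// Guard against a stalled parse: if tokenNum fails to advance between
// iterations, break out rather than looping forever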
let brokenCount;
while (nextToken) {
if (brokenCount !== undefined) {
if (tokenNum === brokenCount) {
if (debug) {
logger.error(
`${recursionIndent} breaking loop at ${tokenNum} because token num DID NOT CHANGE`,
);
}
break;
}
}
brokenCount = tokenNum;
const { token: tokenStr } = nextToken;
const data = tokenStr?.trim();
const isKeyValSepToken = data === keyValSepToken;
const isLineTerm = tokenStr.includes(lineTerm);
const isObjectTerm = tokenStr.includes(objectTerm);
const isObjectOpen = tokenStr.includes(objectOpen);
const isArrayOpen = tokenStr.includes(arrayOpen);
const isArrayTerm = tokenStr.includes(arrayTerm);
const isKeyValueSep = tokenStr.includes(':');
const isComma = tokenStr.includes(',');
const isStringQuote =
tokenStr.includes(stringQuote) && !previousTokenStr?.includes('\\');
// const isTerminator =
// (isStringQuote && (isInsideValue || isInsideKey)) ||
// isLineTerm ||
// isObjectTerm ||
// isArrayTerm;
const isTerminator =
((isInsideValue || isInsideKey) && isStringQuote) ||
(!isInsideString &&
(isLineTerm || isObjectTerm || isArrayTerm || isComma));
// Store previous for escaped quote detection
previousTokenStr = tokenStr;
if (debug) {
logger.debug(
renderParserState({
recursionIndent,
tokenStr,
tokenNum,
data,
isInsideKey,
isInsideValue,
isInsideString,
isStringQuote,
isTerminator,
isObjectOpen,
isObjectTerm,
isArrayOpen,
isArrayTerm,
}),
);
}
// String start OR end
if (!isInsideKey && !isInsideValue) {
if (isObjectOpen || isArrayOpen) {
// Recurse in and consume the child
if (
!consumeChild({
type: isArrayOpen ? ObjectTypes.Array : ObjectTypes.Object,
})
) {
if (debug) {
logger.error(
`${recursionIndent} child error, breaking OBJECT loop`,
);
}
break;
}
jsonPropPath.pop(); // remove the last key from the path
}
// Only start a key if we haven't yet started a key or value
else if (!accum.key) {
if (isStringQuote) {
resetAccum();
isInsideKey = true; // keys ALWAYS are quoted in json, so first quote we find is a key
isInsideString = !!isStringQuote;
}
} else if (data && !isObjectTerm && !isKeyValueSep && !isTerminator) {
isInsideValue = true; // values can be quoted strings, so if we have a key, then the first quote we find is a value
isInsideString = !!isStringQuote;
if (debug) {
logger.debug(`${recursionIndent} > starting value (m1) ...`);
}
accum.list = [];
}
} else if (isInsideKey) {
if (isStringQuote || isKeyValSepToken) {
// get any ending tokens
accum.list.push({ ...nextToken, data, type: 'data' });
// isInside* only gets set once if we are NOT inside either, so we know this is the END of a KEY
let accumData = accum.list
?.map((x) => x.token)
.join('')
.trim();
// Try to parse the data or just use the tokenized data if cannot parse
accum.key = jsonSafeParse(accumData) || accumData;
jsonPropPath.push(accum.key); // add key to path
isInsideKey = false; // end the key accumulation
isInsideString = false;
const keyLogProbs = dedupBy(accum.list, (x) => x.originalIndex).map(
(x) => ({
token: x.originalToken,
logprob: x.logprob,
}),
);
if (keyLogProbs[0].token.includes('"')) {
keyLogProbs.shift();
}
// if (keyLogProbs[keyLogProbs.length - 1]?.token?.includes('"')) {
if (endsWithTerminatorToken(keyLogProbs)) {
keyLogProbs.pop();
}
const annotatedProbs = annotateLogProbs(keyLogProbs);
accum.keyMetadata = {
...normalizeDecimals(annotatedProbs, { decimals: 6 }),
logprobs: keyLogProbs,
};
if (debug) {
logger.warn(`${recursionIndent} KEY-END`, {
key: accum.key,
currentPath: propPathArrayToString(jsonPropPath),
});
}
accum.list = [];
}
}
// If we hit a terminator while inside a value, then we are at the end of a value
else if (isInsideValue && isTerminator) {
// get any ending tokens
if (!isObjectTerm && !isArrayTerm && !isComma) {
accum.list.push({ ...nextToken, data, type: 'data' });
}
isInsideValue = false;
isInsideString = false;
// End of a value
bluntTerminateValue();
jsonPropPath.pop();
resetAccum();
}
if (!isInsideValue && isObjectTerm) {
// just remove the last key from the path
jsonPropPath.pop();
resetAccum();
// End terminator for object
break;
}
// if (!isInsideValue || !isTerminator) {
// Accumulate data for whatever we're doing
accum.list.push({ ...nextToken, data, type: 'data' });
// nextToken = consumableProbs.shift();
// tokenNum++;
nextToken = getNextToken();
if (!nextToken) {
if (debug) {
logger.error(
`${recursionIndent} no next token, breaking OBJECT loop`,
);
}
break;
}
// if (tokenNum >= consumableProbs.length) {
// logger.warn(
// `${recursionIndent} breaking loop at token ${tokenNum} / ${consumableProbs.length}`,
// );
// // process.exit(0);
// break;
// }
}
}
if (/* isInsideKey || */ isInsideValue && !dropUnparsed) {
// Possible during streaming - we have not yet received value termination, so we need to handle this
bluntTerminateValue({ at: 'parseEnd' });
}
const object = Object.fromEntries(
results.map(({ path, ...rest }) => [path, rest]),
);
const simplified = results.map(
({ path, key, value, keyMetadata, valueMetadata, final }) => [
path,
{
// ...valueMetadata,
key,
value,
keyProb: keyMetadata?.jointProb,
valueProb: valueMetadata?.jointProb,
keyPerplexity: keyMetadata?.perplexity,
valuePerplexity: valueMetadata?.perplexity,
finished: !!final,
},
],
);
if (debug) {
logger.debug(
`${recursionIndent} [logprobsToAnnotatedJson] EXIT, simplified:`,
{
// results,
simplified,
},
);
}
return returnSimplifiedOnly
? simplified
: { result: object, simplified, tokenNum };
}
/**
* @function withStreamLogProbsToJson
*
* HOC function that converts a stream of log probabilities to JSON format, when possible, using `logprobsToAnnotatedJson`.
*
* Note that you MUST opt-in to parsing json by setting `enableStreamJsonParse` to `true` in the options object.
* This was added so we don't accidentally overburden stream parsing where the consumer of the prediction
* never needs the json. As the performance of the parser is somewhat unknown (it's not optimized at all),
* we don't want to just always parse json in case it's not needed.
*
* Adds a `.json` property to the chunk object from the LLM and calls your provided `onStreamChunk` method with the full chunk with the added JSON parse and logprob annotations.
*
* E.g. Inspect `chunk.json.simplified` in your chunk handler to get key/value perplexity/logprobs
*
*
* @param {Function} onStreamChunkInput - The callback function to handle each chunk of the stream.
* @param {Object} options - The options object.
* @param {boolean} [options.debug=false] - Whether to enable debug mode.
* @param {boolean} [options.debugParser=false] - Whether to enable debug mode for the parser.
* @param {Logger} [options.logger=Logger] - The logger object to use for logging.
* @param {boolean} [options.enableStreamJsonParse=false] - Whether to enable streaming JSON parsing.
* @param {boolean} [options.dropUnparsed=false] - Whether to drop unparsed log probabilities.
* @returns {Function} - The modified callback function to handle each chunk of the stream.
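* @example
* // Sketch (the handler body is illustrative; wire the returned function into
* // your model call as its stream-chunk callback):
* const onStreamChunk = withStreamLogProbsToJson(
*   (chunk) => {
*     // Inspect / circuit-break mid-stream via the accumulated parse
*     chunk.json?.simplified?.forEach(([path, { value, valueProb }]) => {
*       if (valueProb !== undefined && valueProb < 0.5) {
*         // e.g. abort the stream or flag the result for regeneration here
*       }
*     });
*   },
*   { enableStreamJsonParse: true },
* );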
*/
export function withStreamLogProbsToJson(
onStreamChunkInput,
{
debug = false,
enableStreamJsonParse = false,
debugParser = false,
logger = Logger,
dropUnparsed = false,
},
) {
const logProbAccumulate = [];
const onStreamChunk = (chunk) => {
// logger.debug(
// `[stream]`,
// elideJsonMessage(JSON.stringify(chunk?.simple || chunk)),
// );
const chunkProbs = chunk?.simple?.logprobs?.content || [];
if (!chunkProbs.length && chunk?.simple?.content) {
// for non-logprob models, we still can attempt to do streaming json parsing
// by simulating a single logprob token
chunkProbs.push({
token: chunk.simple.content,
logprob: Math.log(1),
});
}
if (chunkProbs?.length) {
logProbAccumulate.push(...chunkProbs);
}
if (chunkProbs?.length && enableStreamJsonParse) {
/**
* Note that you MUST opt-in to parsing json by setting `enableStreamJsonParse` to `true` in the options object.
* This was added so we don't accidentally overburden stream parsing where the consumer of the prediction
* never needs the json. As the performance of the parser is somewhat unknown (it's not optimized at all),
* we don't want to just always parse json in case it's not needed.
*/
const { result, simplified } = logprobsToAnnotatedJson({
logprobs: logProbAccumulate,
logger,
debug: debugParser,
dropUnparsed,
});
// logger.debug(`[stream]`, elideJsonMessage(JSON.stringify(simplified)));
if (debug) {
logger.debug(`[stream]`, simplified);
}
onStreamChunkInput({
...chunk,
json: {
simplified,
content: result,
},
});
} else {
onStreamChunkInput(chunk);
}
// process.stdout.write(chunk?.simple?.content || '');
};
onStreamChunk.logProbAccumulate = logProbAccumulate;
return onStreamChunk;
}
/**
* Factory function that creates a debug handler for stream chunks to log the simplified JSON
* output to the logger (or write chunks to stdout if no chunk.json.simplified given)
*
* @param {Logger} logger - The logger instance to use for logging.
* @returns {Function} - The debug handler function.
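* @example
* // Sketch: pair the debug handler with the JSON-parsing stream wrapper so the
* // log shows per-key values as they finish streaming.
* const handler = withStreamLogProbsToJson(onStreamChunkDebugFactory(), {
*   enableStreamJsonParse: true,
* });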
*/
export function onStreamChunkDebugFactory(logger = Logger) {
return (chunk) => {
// process.stdout.write(chunk?.simple?.content || '');
if (chunk?.json?.simplified) {
logger.debug(
'[stream]\n',
// elideJsonMessage(chunk?.simple || chunk),
chunk?.json?.simplified
?.map(([key, value]) =>
chalk[value?.finished ? 'yellow' : 'green'](
' '.repeat(8) + [key, value?.value].join(': '),
),
)
.join('\n'),
);
} else {
process.stdout.write(chunk?.simple?.content || '');
}
};
}
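/**
* Rebuilds a plain JavaScript object from the `simplified` entries produced by
* `logprobsToAnnotatedJson`, writing each entry's value at its dotted path via
* dot-prop's setProperty.
*
* @param {Array} simplified - Array of [path, data] entries from the parser.
* @param {Object} [options] - The options object.
* @param {Logger} [options.logger=Logger] - Logger instance (currently unused).
* @param {boolean} [options.debug=false] - Debug flag (currently unused).
* @param {boolean} [options.dualVersion=false] - When true, return [plainObject, objectWithMetadata] instead of just the plain object.
* @returns {Object|Array} The reconstructed object, or a two-element array when dualVersion is set.
*
* @example
* // Sketch: collapse simplified entries back into a plain object
* const obj = convertSimplifiedJsonToObject([
*   ['sentiment.negative', { key: 'negative', value: 0.1, finished: true }],
* ]);
* // obj => { sentiment: { negative: 0.1 } }
*/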
export function convertSimplifiedJsonToObject(
simplified,
{ logger = Logger, debug = false, dualVersion = false } = {},
) {
const outputObject = {};
const outputWithMetadata = {};
simplified?.forEach(([key, value]) => {
setProperty(outputObject, key, value.value);
setProperty(outputWithMetadata, key, value);
});
if (dualVersion) {
return [outputObject, outputWithMetadata];
}
return outputObject;
}
const simpleMain = () => {
const example = {
scope: 'session',
operation: 'create_item',
name: 'user_info',
data: [
'{"name": "NVL Test Login", "email": "nlvp@example.com", "phoneNum": "+15551212224"}',
'application/json',
{ 'Content-Type': 'application/json' },
{ quotedMultiPart: true },
// [1, 2, 'three', [[[4]]], 5.6],
],
};
const json = JSON.stringify(example, null, 0);
// Using Array.from instead of split("") to handle unicode characters properly
// Add fake logprob to the fake token
const logprobs = Array.from(json).map((token) => ({ token, logprob: Math.log(Math.random()) }));
const logger = Logger.getPrefixedCustomLogger('jsonParse');
const object = logprobsToAnnotatedJson({ logprobs, logger, debug: true });
logger.debug(
object?.simplified || object,
json,
convertSimplifiedJsonToObject(object?.simplified, { dualVersion: false }),
);
};
if (require.main === module) {
try {
simpleMain();
} catch (ex) {
Logger.error(ex);
}
process.exit(0);
}