/**
 * Coax an LLM into generating grounded JSON using key/value-level perplexity checks: a custom
 * `yup` validation test for perplexity, model fallbacks, and a customizable failure-injection callback
 * that injects grounding instructions for exact paths (e.g. for specific keys whose perplexity is too high).
*
* Contributing to the community as an example of how to coax an LLM into generating JSON that conforms
* to a schema with key-level perplexity checks and fallback models for retries.
*
* Write up at: https://dev.to/josiahbryan/tackling-json-perplexity-in-llm-outputs-a-weekend-project-jm8
*
* Public gist at: https://gist.github.com/josiahbryan/54dd184a9614882f1ee9ea413110d95e
*/
/*
# Example run
Example of a successful run of this script showing that it retried when perplexity was too high for the "formalName"
prop and finally succeeded. See the "executeMain" section at the bottom for the schema and prompt given.
```log
(PID 74676) 2024-04-15T04:17:35.977Z [DEBUG] EventBus.js:83: Creating EventBus 216bfe7c-34b7-4da0-a712-fa8f135149bd on channel: rubber-EventBus
(PID 74676) 2024-04-15T04:17:35.987Z coax-llm.js [DEBUG] script-cli.js:83: Running...
warning: setting timezone 'Etc/GMT+0' fails on server.
look at https://mariadb.com/kb/en/mysql_tzinfo_to_sql/ to load IANA timezone.
Setting timezone can be disabled with option `skipSetTimezone`
{
"formalName": "Emily",
"nickname": "Em",
"ageGuess": 25
}(PID 74676) 2024-04-15T04:17:36.875Z coax-llm.js [DEBUG] coax-llm.js:162: Cannot use generated JSON because it failed schema validation: Error: High perplexity detected in "Emily" at "formalName", indicating potential inaccuracies or lack of grounding in facts. Please refine to be more grounded in factual content and closely aligned with the context provided. Target perplexity under: 1.125. {
path: 'formalName',
type: 'perplexity',
errors: [
'Error: High perplexity detected in "Emily" at "formalName", indicating potential inaccuracies or lack of grounding in facts. Please refine to be more grounded in factual content and closely aligned with the context provided. Target perplexity under: 1.125.'
]
}
(PID 74676) 2024-04-15T04:17:36.875Z coax-llm.js [WARN] coax-llm.js:233: Retrying with fallback model: gpt-3.5-turbo-0125
{
"formalName": "Josiah Bryan",
"nickname": "Joey",
"ageGuess": 28
}(PID 74676) 2024-04-15T04:17:37.519Z coax-llm.js [INFO] coax-llm.js:330: Generated content: {
"formalName": "Josiah Bryan",
"nickname": "Joey",
"ageGuess": 28
}
(PID 74676) 2024-04-15T04:17:37.519Z coax-llm.js [INFO] coax-llm.js:334: Generated object: { formalName: 'Josiah Bryan', nickname: 'Joey', ageGuess: 28 }
(PID 74676) 2024-04-15T04:17:37.519Z coax-llm.js [INFO] coax-llm.js:338: Generated object with metadata: {
formalName: {
key: 'formalName',
value: 'Josiah Bryan',
keyProb: 0.999996,
valueProb: 0.999957,
keyPerplexity: 1.000001,
valuePerplexity: 1.000014,
finished: true
},
nickname: {
key: 'nickname',
value: 'Joey',
keyProb: 0.999996,
valueProb: 0.872926,
keyPerplexity: 1.000004,
valuePerplexity: 1.070314,
finished: true
},
ageGuess: {
key: 'ageGuess',
value: 28,
keyProb: 0.999994,
valueProb: 0.594872,
keyPerplexity: 1.000003,
valuePerplexity: 1.681035,
finished: true
}
}
```
*/
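/*
# Example usage

A minimal usage sketch, assuming the schema fields from the example run above (formalName, nickname,
ageGuess). The prompt text and threshold values here are illustrative only; the real invocation lives
in the `executeMain` section mentioned above, which is not included in this excerpt. Only the option
names come from the `coaxLlm` signature below.
```js
const result = await coaxLlm({
  prompt: 'Given the bio below, extract the person it describes as JSON. Bio: ...',
  // `yup` here is the perplexity-aware instance exported from ./yupPerplexityTest
  schema: yup.object().shape({
    formalName: yup.string().required(),
    nickname: yup.string().required(),
    ageGuess: yup.number().required(),
  }),
  modelList: DEFAULT_COAX_MODEL_LIST,
  maxRootPerplexity: 3,
  debug: true,
});

// `generate` resolves to { content, object, objectWithMetadata } on success, or { failure } once
// every model in `modelList` has been exhausted; the exact wrapper shape depends on `safeEndLangfuse`.
Logger.info('coaxLlm result:', result);
```
*/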
/* eslint-disable no-unused-vars */
import { ModelIds } from 'shared/utils/ChatBotDefEnums';
import Logger from 'shared/utils/Logger';
// import { yup, convertSchema } from 'shared/requests/utils/AuthoringHelpers';
import { readFile } from 'shared/utils/asyncFs';
import { executeMain } from '../../../../../utils/script-cli';
import { llmPredictFactory } from '../../../../chatbot/utils/llmPredictFactory';
import {
convertSimplifiedJsonToObject,
onStreamChunkDebugFactory,
} from '../../../../chatbot/utils/withStreamLogProbsToJson';
import {
safeEndLangfuse,
traceFactory,
} from '../../../../chatbot/utils/langfuse';
import { yup, convertSchema } from './yupPerplexityTest';
export const DEFAULT_COAX_MODEL_LIST = [
  ModelIds.Gpt3_5Turbo, // effectively tried twice: used for the first attempt and again for the first retry, due to the retry-index logic
ModelIds.Gpt4, // most likely to work...
ModelIds.Claude_3_Opus, // no logprobs!
];
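/*
`yup` and `convertSchema` above come from ./yupPerplexityTest, which is not included in this file.
Purely as an illustration of the idea, a per-field perplexity test could be wired up roughly like the
sketch below: it reads the `simplified` logprob entries passed in via the validation `context` (see
the `schema.validate(...)` call inside `generate`) and rejects values whose measured perplexity
exceeds a threshold. `groundedString` and its internals are assumptions, not the actual contents of
yupPerplexityTest.
```js
import * as yup from 'yup';

// Hypothetical helper: a string field that fails validation when its value-level perplexity
// (computed from streamed logprobs) is above `limit`.
const groundedString = (limit = 1.125) =>
  yup.string().test('perplexity', 'High perplexity detected', function check(value) {
    const { path, options } = this;
    const { simplified = [] } = options.context || {};
    // Each simplified entry looks like ['path.to.key', { valuePerplexity: 1.2, ... }]
    const [, meta] = simplified.find(([entryPath]) => entryPath === path) || [];
    if (meta && meta.valuePerplexity > limit) {
      return this.createError({
        path,
        message: `Error: High perplexity detected in "${value}" at "${path}". Target perplexity under: ${limit}.`,
      });
    }
    return true;
  });

// Hypothetical usage: yup.object().shape({ formalName: groundedString(1.125).required() })
```
*/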
/**
* Coax an LLM into generating grounded JSON based on key/value-level perplexity checks
*
* @param {Object} options - The options for generating content.
* @param {string} options.prompt - The prompt for generating content.
 * @param {string[]} [options.modelList] - The models to try, in order; later entries act as fallbacks when generation fails.
* @param {yup.MixedSchema} [options.schema] - The schema to validate the generated JSON against.
* @param {Logger} [options.logger] - The logger to use for logging.
* @param {boolean} [options.includeSchemaInPrompt] - Whether to include the schema in the prompt.
* @param {Function} [options.failureInjectCallback] - The callback function to inject additional instructions on failure.
* @param {Function} [options.customFailureEvaluator] - The custom failure evaluator to handle more complex failure cases.
* @param {Object} [options.langfuseTrace] - The langfuse trace object.
* @param {string} [options.templateId] - The template ID.
* @param {string} [options.spanName] - The span name.
* @param {number} [options.maxRootPerplexity] - The maximum perplexity allowed for the root content.
* @param {boolean} [options.enableStreamJsonParse] - Whether to enable stream JSON parsing.
 * @param {string} [options.cacheMode] - The cache mode to use.
 * @param {boolean} [options.debug] - Whether to log debug output for validation failures and retries.
 * @param {Function} [options.onStreamChunk] - Callback invoked for each streamed chunk; defaults to a debug logger.
 * @returns {Promise<Object>} - The generated content and metadata.
*/
export async function coaxLlm({
prompt: promptInput,
modelList = DEFAULT_COAX_MODEL_LIST,
schema = yup.object().shape({}),
logger = Logger,
// If true, converts the schema prop to a JSON Schema and injects it at the END of the prompt
includeSchemaInPrompt = true,
// Async method that can be used to inject additional instructions on failure
failureInjectCallback = async ({
retryCount, // depth into retries
modelName, // model name the failure was generated with
message, // error message
type, // type of failure (e.g. perplexity, typeError, etc.)
errors, // array of errors if from yup
data, // data generated that caused a failure
content, // raw content from LLM if any
jsonLogProbs, // logprobs for the JSON generation (see simplified or result props)
}) =>
    // Should return an array of strings to inject into the prompt AFTER the failure message.
    // (The failure message itself is already injected automatically at the start of the prompt,
    // before the original prompt; these lines go right after it, before the original prompt is
    // added back in.) See the callback sketches after this function for an illustrative example.
[],
// Custom failure evaluator to handle more complex failure cases that cannot be done in `yup`
customFailureEvaluator = async ({
// Parsed JSON object
object,
// Raw content generated from LLM
content,
// Metadata from the LLM response
metadata,
// Span in langfuse that wraps this evaluator
// eslint-disable-next-line no-shadow
langfuseTrace,
}) => {
    // If there is no error, return anything falsey (null/undefined/false).
    // If there is an error, return:
    //   { type: '<type>', path: '0.whatever', message: '<error message, given to the LLM on retry>' }
    // See the callback sketches after this function for an illustrative example.
},
langfuseTrace,
templateId = 'coax-llm',
spanName = templateId,
maxRootPerplexity = 3,
enableStreamJsonParse = true,
cacheMode = 'enabled',
debug = false,
onStreamChunk: onStreamChunkInput,
}) {
const prompt = [promptInput];
const span = langfuseTrace?.span({
name: spanName,
input: {
prompt: promptInput,
schema: convertSchema(schema),
includeSchemaInPrompt,
},
});
if (includeSchemaInPrompt) {
prompt.push(
'',
`### Schema\n\n`,
'```json\n',
JSON.stringify(convertSchema(schema), null, 2),
'```\n\n',
`Be sure to follow the schema above when generating JSON and output valid JSON that conforms to the schema:`,
);
}
const debugStream = onStreamChunkDebugFactory(logger);
const generate = async (args) => {
const { modelName, promptInject = [], retryCount = 0 } = args;
// Support passing in a function as predict instead of the model name
const predict =
typeof modelName === 'function'
? modelName
: llmPredictFactory({
modelName,
});
const { content, metadata } = await predict(
[...promptInject, ...prompt].join('\n'),
{
cacheMode,
// enableStreamJsonParse,
onStreamChunk: onStreamChunkInput || debugStream,
returnMetadata: true,
langfuseTrace: span,
heliconeCustomProps: {
templateId,
},
},
);
let failure;
let object;
let objectWithMetadata;
const { json: jsonLogProbs } = metadata || {};
const { simplified = [] } = jsonLogProbs || {};
    // If a schema was given but no JSON was extracted from the content, generate a failure
    // (we can't even attempt schema validation without JSON)
if (schema && !simplified.length) {
if (debug) {
logger.debug(
`Schema given, but no JSON extracted from content, cannot enforce schema, generating failure for schema validation`,
);
}
failure = {
message: 'Error: No JSON extracted from content',
type: 'noJson',
path: '$root',
data: content,
jsonLogProbs,
};
}
    // If a schema was provided and there is no failure yet (i.e. we have simplified JSON),
    // check it for schema validation failures before running any further checks
if (!failure && schema) {
// Use our blunt-force JSON parser output that converted the token+logprobs into
// a simplified array of entries like ["path.to.key", { valuePerplexity: 1.2 }] into
// a set of two JSON objects - one with the original JSON and one with the metadata
// for each key in the JSON
[object, objectWithMetadata] = convertSimplifiedJsonToObject(simplified, {
dualVersion: true,
});
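      // For example (from the run logged at the top of this file), `object` comes back as
      // { formalName: 'Josiah Bryan', nickname: 'Joey', ageGuess: 28 }, while `objectWithMetadata`
      // carries per-key stats such as:
      // { formalName: { key, value, keyProb, valueProb, keyPerplexity, valuePerplexity, finished }, ... }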
await schema
.validate(object, {
// Pass the simplified array of JSON entries to the context for the yup schema to use
// in the custom test for perplexity (above)
context: {
simplified,
logger,
},
})
.catch((ex) => {
const { path, type, errors } = ex;
          // If the failure came from our custom 'perplexity' test (from yupPerplexityTest), `type` will
          // equal "perplexity" and `errors` will contain the custom error message from the test function
if (debug) {
logger.debug(
`Cannot use generated JSON because it failed schema validation: ${ex.message}`,
{
path,
type,
errors,
},
);
}
failure = {
message: `Error in generated JSON: ${ex.message}`,
type,
path,
errors,
data: object,
content,
jsonLogProbs,
};
});
}
// No failure yet? Check if there is no content generated
if (!failure && !content) {
failure = {
message: 'Error: No content generated',
type: 'noContent',
path: '$root',
data: null,
};
}
// No failure yet? Check the root perplexity
if (!failure && metadata?.perplexity > maxRootPerplexity) {
failure = {
message: `Error: High perplexity detected in root content, indicating potential inaccuracies or lack of grounding in facts. Please refine to be more grounded in factual content and closely aligned with the context provided. Target perplexity under: ${maxRootPerplexity}.`,
type: 'perplexity',
path: '$root',
data: content,
};
}
if (!failure) {
// No failure yet? Run any provided custom failure detection
const customFailure = await customFailureEvaluator({
object,
content,
metadata,
langfuseTrace: span,
});
if (customFailure) {
if (debug) {
logger.debug('Custom failure detected:', customFailure);
}
failure = customFailure;
}
}
// Handle generated failures by retrying up to a certain limit
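    // Note: the first call passes retryCount = 0, so the first retry re-uses modelList[0]
    // (i.e. the first model is effectively attempted twice) before moving down the list;
    // see the note on DEFAULT_COAX_MODEL_LIST above.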
if (failure) {
if (retryCount < modelList.length) {
const { type, path, message, errors, ...failureRest } = failure;
const inject = [
`## IMPORTANT`,
`Failure detected in previous attempt: ${message}`,
`Failure type: ${type}`,
`Failure path: ${path}`,
...(type === 'noJson' && failure?.data
? [
`Generate JSON from the previously generated content based on the schema given below: "${failure.data}"`,
'',
]
: []),
          `Pay attention to the failure message above and attempt the generation again.`,
'',
          // Only run the failure-injection callback in the event of an actual failure, obviously
...(await failureInjectCallback({
langfuseTrace: span,
modelName,
retryCount,
...failure,
})),
'', // Add a blank line before the original prompt
];
span?.event({
name: 'coax-failure',
metadata: {
type,
path,
errors,
simplifiedJsonLogProbs: failureRest.jsonLogProbs?.simplified,
content,
},
input: { type, path, errors },
output: {
retryCount,
inject,
...(typeof modelList[retryCount] !== 'function'
? { nextModel: modelList[retryCount] }
: {}),
},
level: 'ERROR',
statusMessage: failure.message,
});
if (debug) {
logger.warn(`Retrying with fallback model: ${modelList[retryCount]}`);
}
return generate({
modelName: modelList[retryCount],
// promptInject: `### Retry with ${modelName} failed. Trying with ${fallbackModels[retryCount]}`,
promptInject: inject,
retryCount: retryCount + 1,
});
}
return { failure };
}
return { content, object, objectWithMetadata };
};
return safeEndLangfuse(span, {
output: await generate({ modelName: modelList[0] }),
});
}
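/*
# Callback sketches

Minimal sketches of the two optional callbacks, as referenced from the inline comments in the
`coaxLlm` signature above. Both are illustrative assumptions about how a caller might use the
documented arguments; they are not part of this module.
```js
// failureInjectCallback: return extra prompt lines that get spliced in right after the failure message.
const groundingInjectCallback = async ({ type, path, retryCount }) => {
  if (type !== 'perplexity') {
    return [];
  }
  return [
    `The value at "${path}" looked ungrounded on attempt ${retryCount + 1}.`,
    `Only use names, numbers, and facts that appear verbatim in the provided context; do not guess.`,
  ];
};

// customFailureEvaluator: return a falsey value when the generated object is acceptable, or a
// { type, path, message } failure descriptor to trigger another retry.
const ageSanityEvaluator = async ({ object }) => {
  if (object && (object.ageGuess < 0 || object.ageGuess > 120)) {
    return {
      type: 'custom',
      path: 'ageGuess',
      message: `Error: ageGuess of ${object.ageGuess} is not a plausible human age.`,
    };
  }
  return null;
};

// Hypothetical usage:
// await coaxLlm({ prompt, schema, failureInjectCallback: groundingInjectCallback, customFailureEvaluator: ageSanityEvaluator });
```
*/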